From a5bf78f3f0072fb7b6b932116e005e803665f029 Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Wed, 19 Feb 2020 13:43:35 -0800 Subject: [PATCH] [BEAM-9344] Add support for bundle finalization to the Beam Java SDK. This change adds support for passing in the BundleFinalizer as an input parameter to start/process/finish and wires it through within the SDK harness. The only supported combination that can execute this is Python ULR and Dataflow using UW which is why all the validates runner test configurations have added the test category to the exclusions list. --- runners/apex/build.gradle | 1 + .../SplittableParDoNaiveBounded.java | 61 +- ...oundedSplittableProcessElementInvoker.java | 7 + .../beam/runners/core/SimpleDoFnRunner.java | 25 + .../SplittableParDoViaKeyedWorkItems.java | 70 +- runners/direct-java/build.gradle | 2 + runners/flink/flink_runner.gradle | 1 + .../flink/job-server/flink_job_server.gradle | 1 + runners/gearpump/build.gradle | 1 + .../google-cloud-dataflow-java/build.gradle | 1 + runners/jet/build.gradle | 2 + runners/portability/java/build.gradle | 5 - runners/samza/build.gradle | 1 + runners/spark/build.gradle | 2 + runners/spark/job-server/build.gradle | 1 + .../beam/sdk/testing/UsesBundleFinalizer.java | 26 + .../org/apache/beam/sdk/transforms/DoFn.java | 2 - .../beam/sdk/transforms/DoFnTester.java | 63 +- .../reflect/ByteBuddyDoFnInvokerFactory.java | 17 +- .../sdk/transforms/reflect/DoFnInvoker.java | 24 +- .../sdk/transforms/reflect/DoFnSignature.java | 52 +- .../transforms/reflect/DoFnSignatures.java | 96 ++- .../sdk/transforms/SplittableDoFnTest.java | 79 ++ .../transforms/reflect/DoFnInvokersTest.java | 22 +- .../reflect/DoFnSignaturesTest.java | 56 +- .../beam/fn/harness/BeamFnDataReadRunner.java | 4 +- .../fn/harness/BeamFnDataWriteRunner.java | 4 +- .../beam/fn/harness/BoundedSourceRunner.java | 4 +- .../beam/fn/harness/CombineRunners.java | 4 +- .../apache/beam/fn/harness/FlattenRunner.java | 4 +- 
.../beam/fn/harness/FnApiDoFnRunner.java | 757 +++++++++--------- .../org/apache/beam/fn/harness/FnHarness.java | 13 +- .../apache/beam/fn/harness/MapFnRunners.java | 4 +- .../fn/harness/PTransformRunnerFactory.java | 7 +- .../control/FinalizeBundleHandler.java | 161 ++++ .../harness/control/ProcessBundleHandler.java | 57 +- .../fn/harness/AssignWindowsRunnerTest.java | 3 +- .../fn/harness/BeamFnDataReadRunnerTest.java | 6 +- .../fn/harness/BeamFnDataWriteRunnerTest.java | 3 +- .../fn/harness/BoundedSourceRunnerTest.java | 3 +- .../beam/fn/harness/CombineRunnersTest.java | 4 + .../beam/fn/harness/FlattenRunnerTest.java | 6 +- .../beam/fn/harness/FnApiDoFnRunnerTest.java | 24 +- .../beam/fn/harness/MapFnRunnersTest.java | 9 +- .../control/FinalizeBundleHandlerTest.java | 115 +++ .../control/ProcessBundleHandlerTest.java | 118 ++- 46 files changed, 1404 insertions(+), 524 deletions(-) create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesBundleFinalizer.java create mode 100644 sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java create mode 100644 sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java diff --git a/runners/apex/build.gradle b/runners/apex/build.gradle index cb4d62126cd40..717996a4e0523 100644 --- a/runners/apex/build.gradle +++ b/runners/apex/build.gradle @@ -100,6 +100,7 @@ task validatesRunnerBatch(type: Test) { excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedPCollections' // TODO[BEAM-8304]: Support multiple side inputs with different coders. excludeCategories 'org.apache.beam.sdk.testing.UsesSideInputsWithDifferentCoders' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } // apex runner is run in embedded mode. 
Increase default HeapSize diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java index cf40ddedd400d..30a44da55e8f6 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java @@ -31,6 +31,7 @@ import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.state.TimerMap; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.StartBundleContext; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Reshuffle; @@ -132,10 +133,21 @@ public void setup() { @StartBundle public void startBundle(StartBundleContext c) { invoker.invokeStartBundle( - new DoFn.StartBundleContext() { + new BaseArgumentProvider() { @Override - public PipelineOptions getPipelineOptions() { - return c.getPipelineOptions(); + public DoFn.StartBundleContext startBundleContext( + DoFn doFn) { + return new DoFn.StartBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return c.getPipelineOptions(); + } + }; + } + + @Override + public String getErrorContext() { + return "SplittableParDoNaiveBounded/StartBundle"; } }); } @@ -174,23 +186,35 @@ public String getErrorContext() { @FinishBundle public void finishBundle(FinishBundleContext c) { invoker.invokeFinishBundle( - new DoFn.FinishBundleContext() { - @Override - public PipelineOptions getPipelineOptions() { - return c.getPipelineOptions(); - } - + new BaseArgumentProvider() { @Override - public void output(@Nullable OutputT output, Instant timestamp, BoundedWindow window) { - throw new UnsupportedOperationException( -
"Output from FinishBundle for SDF is not supported"); + public DoFn.FinishBundleContext finishBundleContext( + DoFn doFn) { + return new DoFn.FinishBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return c.getPipelineOptions(); + } + + @Override + public void output( + @Nullable OutputT output, Instant timestamp, BoundedWindow window) { + throw new UnsupportedOperationException( + "Output from FinishBundle for SDF is not supported"); + } + + @Override + public void output( + TupleTag tag, T output, Instant timestamp, BoundedWindow window) { + throw new UnsupportedOperationException( + "Output from FinishBundle for SDF is not supported"); + } + }; } @Override - public void output( - TupleTag tag, T output, Instant timestamp, BoundedWindow window) { - throw new UnsupportedOperationException( - "Output from FinishBundle for SDF is not supported"); + public String getErrorContext() { + return "SplittableParDoNaiveBounded/FinishBundle"; } }); } @@ -317,6 +341,11 @@ public OutputReceiver getRowReceiver(TupleTag tag) { }; } + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException(); + } + @Override public Object restriction() { return tracker.currentRestriction(); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java index 21bdf31d991de..010eb11a9425f 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java @@ -32,6 +32,7 @@ import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.state.TimerMap; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext; import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; @@ -169,6 +170,12 @@ public MultiOutputReceiver taggedOutputReceiver(DoFn doFn) { return DoFnOutputReceivers.windowedMultiReceiver(processContext, null); } + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + "Not supported in non-portable SplittableDoFn"); + } + @Override public RestrictionTracker restrictionTracker() { return processContext.tracker; diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java index e4362ebed2072..71efa128209c6 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java @@ -38,6 +38,7 @@ import org.apache.beam.sdk.state.TimerMap; import org.apache.beam.sdk.state.TimerSpec; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; import org.apache.beam.sdk.transforms.DoFnOutputReceivers; @@ -391,6 +392,12 @@ public TimerMap timerFamily(String tagId) { throw new UnsupportedOperationException( "Cannot access timer family outside of @ProcessElement and @OnTimer methods"); } + + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + "Bundle finalization is not supported in non-portable pipelines."); + } } /** B A concrete implementation of {@link DoFn.FinishBundleContext}. 
*/ @@ -538,6 +545,12 @@ public void output(OutputT output, Instant timestamp, BoundedWindow window) { public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { outputWindowedValue(tag, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); } + + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + "Bundle finalization is not supported in non-portable pipelines."); + } } /** @@ -791,6 +804,12 @@ public TimerMap timerFamily(String timerFamilyId) { throw new RuntimeException(e); } } + + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + "Bundle finalization is not supported in non-portable pipelines."); + } } /** @@ -1014,6 +1033,12 @@ public void output(TupleTag tag, T output) { public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { outputWindowedValue(tag, WindowedValue.of(output, timestamp, window(), PaneInfo.NO_FIRING)); } + + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + "Bundle finalization is not supported in non-portable pipelines."); + } } private class TimerInternalsTimer implements Timer { diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java index 735e864e8e481..dc795a55dafef 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java @@ -36,6 +36,9 @@ import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext; +import org.apache.beam.sdk.transforms.DoFn.ProcessContext; +import 
org.apache.beam.sdk.transforms.DoFn.StartBundleContext; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; @@ -414,39 +417,62 @@ public String getErrorContext() { TimerInternals.TimerData.of(stateNamespace, wakeupTime, TimeDomain.PROCESSING_TIME)); } - private DoFn.StartBundleContext wrapContextAsStartBundle( + private DoFnInvoker.ArgumentProvider wrapContextAsStartBundle( final StartBundleContext baseContext) { - return fn.new StartBundleContext() { + return new BaseArgumentProvider() { @Override - public PipelineOptions getPipelineOptions() { - return baseContext.getPipelineOptions(); + public DoFn.StartBundleContext startBundleContext( + DoFn doFn) { + return fn.new StartBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return baseContext.getPipelineOptions(); + } + }; } - }; - } - private DoFn.FinishBundleContext wrapContextAsFinishBundle( - final FinishBundleContext baseContext) { - return fn.new FinishBundleContext() { @Override - public void output(OutputT output, Instant timestamp, BoundedWindow window) { - throwUnsupportedOutput(); + public String getErrorContext() { + return "SplittableParDoViaKeyedWorkItems/StartBundle"; } + }; + } + private DoFnInvoker.ArgumentProvider wrapContextAsFinishBundle( + final FinishBundleContext baseContext) { + return new BaseArgumentProvider() { @Override - public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { - throwUnsupportedOutput(); + public DoFn.FinishBundleContext finishBundleContext( + DoFn doFn) { + return fn.new FinishBundleContext() { + @Override + public void output(OutputT output, Instant timestamp, BoundedWindow window) { + throwUnsupportedOutput(); + } + + @Override + public void output( + TupleTag tag, T output, Instant timestamp, BoundedWindow window) { + throwUnsupportedOutput(); + } + + @Override + public PipelineOptions 
getPipelineOptions() { + return baseContext.getPipelineOptions(); + } + + private void throwUnsupportedOutput() { + throw new UnsupportedOperationException( + String.format( + "Splittable DoFn can only output from @%s", + ProcessElement.class.getSimpleName())); + } + }; } @Override - public PipelineOptions getPipelineOptions() { - return baseContext.getPipelineOptions(); - } - - private void throwUnsupportedOutput() { - throw new UnsupportedOperationException( - String.format( - "Splittable DoFn can only output from @%s", - ProcessElement.class.getSimpleName())); + public String getErrorContext() { + return "SplittableParDoViaKeyedWorkItems/FinishBundle"; } }; } diff --git a/runners/direct-java/build.gradle b/runners/direct-java/build.gradle index 6d4652872e427..98dca8e8efe9d 100644 --- a/runners/direct-java/build.gradle +++ b/runners/direct-java/build.gradle @@ -110,6 +110,7 @@ task needsRunnerTests(type: Test) { // MetricsPusher isn't implemented in direct runner excludeCategories "org.apache.beam.sdk.testing.UsesMetricsPusher" excludeCategories "org.apache.beam.sdk.testing.UsesCrossLanguageTransforms" + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } } @@ -134,6 +135,7 @@ task validatesRunner(type: Test) { excludeCategories "org.apache.beam.sdk.testing.LargeKeys\$Above100MB" excludeCategories 'org.apache.beam.sdk.testing.UsesMetricsPusher' excludeCategories "org.apache.beam.sdk.testing.UsesCrossLanguageTransforms" + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } } diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle index ce8cf191af8d4..54eb4c54e133b 100644 --- a/runners/flink/flink_runner.gradle +++ b/runners/flink/flink_runner.gradle @@ -202,6 +202,7 @@ def createValidatesRunnerTask(Map m) { excludeCategories 'org.apache.beam.sdk.testing.LargeKeys$Above100MB' excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics' excludeCategories 
'org.apache.beam.sdk.testing.UsesSystemMetrics' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' if (config.streaming) { excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap' excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse' diff --git a/runners/flink/job-server/flink_job_server.gradle b/runners/flink/job-server/flink_job_server.gradle index 168860d38c4ff..691163902f3c0 100644 --- a/runners/flink/job-server/flink_job_server.gradle +++ b/runners/flink/job-server/flink_job_server.gradle @@ -149,6 +149,7 @@ def portableValidatesRunnerTask(String name, Boolean streaming) { excludeCategories 'org.apache.beam.sdk.testing.UsesSetState' excludeCategories 'org.apache.beam.sdk.testing.UsesStrictTimerOrdering' excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' if (streaming) { excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime' excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages' diff --git a/runners/gearpump/build.gradle b/runners/gearpump/build.gradle index f9bb2594f7695..57d574dec25df 100644 --- a/runners/gearpump/build.gradle +++ b/runners/gearpump/build.gradle @@ -85,6 +85,7 @@ task validatesRunnerStreaming(type: Test) { excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle' excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse' excludeCategories 'org.apache.beam.sdk.testing.UsesMetricsPusher' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } } diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle index 3e3b1f7a632ee..5e258960e4e08 100644 --- a/runners/google-cloud-dataflow-java/build.gradle +++ b/runners/google-cloud-dataflow-java/build.gradle @@ -142,6 +142,7 @@ def commonExcludeCategories = [ 'org.apache.beam.sdk.testing.UsesTestStream', 'org.apache.beam.sdk.testing.UsesParDoLifecycle', 
'org.apache.beam.sdk.testing.UsesMetricsPusher', + 'org.apache.beam.sdk.testing.UsesBundleFinalizer' ] def fnApiWorkerExcludeCategories = [ diff --git a/runners/jet/build.gradle b/runners/jet/build.gradle index f38835c0df4e0..491f3ceb73782 100644 --- a/runners/jet/build.gradle +++ b/runners/jet/build.gradle @@ -74,6 +74,7 @@ task validatesRunnerBatch(type: Test) { excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap' excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse' //impulse doesn't cooperate properly with Jet when multiple cluster members are used exclude '**/SplittableDoFnTest.class' //Splittable DoFn functionality not yet in the runner + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } maxHeapSize = '4g' @@ -97,6 +98,7 @@ task needsRunnerTests(type: Test) { useJUnit { includeCategories "org.apache.beam.sdk.testing.NeedsRunner" excludeCategories "org.apache.beam.sdk.testing.LargeKeys\$Above100MB" + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } } diff --git a/runners/portability/java/build.gradle b/runners/portability/java/build.gradle index 9c529a60511bc..b241ec0024606 100644 --- a/runners/portability/java/build.gradle +++ b/runners/portability/java/build.gradle @@ -23,11 +23,6 @@ description = "Apache Beam :: Runners :: Portability :: Java" ext.summary = """A Java implementation of the Beam Model which utilizes the portability framework to execute user-definied functions.""" - -configurations { - validatesRunner -} - dependencies { compile library.java.vendored_guava_26_0_jre compile library.java.hamcrest_library diff --git a/runners/samza/build.gradle b/runners/samza/build.gradle index a041eabec9ddd..ada72bde85dd3 100644 --- a/runners/samza/build.gradle +++ b/runners/samza/build.gradle @@ -87,6 +87,7 @@ task validatesRunner(type: Test) { excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle' excludeCategories 'org.apache.beam.sdk.testing.UsesStrictTimerOrdering' excludeCategories 
'org.apache.beam.sdk.testing.UsesTimerMap' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } } diff --git a/runners/spark/build.gradle b/runners/spark/build.gradle index d5249d51dc4a1..2325e2b238d4c 100644 --- a/runners/spark/build.gradle +++ b/runners/spark/build.gradle @@ -148,6 +148,7 @@ task validatesRunnerBatch(type: Test) { // Portability excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse' excludeCategories 'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } jvmArgs '-Xmx3g' } @@ -214,6 +215,7 @@ task validatesStructuredStreamingRunnerBatch(type: Test) { excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse' excludeCategories 'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms' excludeCategories 'org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } filter { // Combine with context not implemented diff --git a/runners/spark/job-server/build.gradle b/runners/spark/job-server/build.gradle index 0ff64883ec573..f65b95fff15d4 100644 --- a/runners/spark/job-server/build.gradle +++ b/runners/spark/job-server/build.gradle @@ -113,6 +113,7 @@ def portableValidatesRunnerTask(String name) { excludeCategories 'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs' excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo' excludeCategories 'org.apache.beam.sdk.testing.UsesStrictTimerOrdering' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' }, ) } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesBundleFinalizer.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesBundleFinalizer.java new file mode 100644 index 0000000000000..5cd2c5b876ee6 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesBundleFinalizer.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.testing; + +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; + +/** + * Category tag for validation tests which use {@link BundleFinalizer}. Tests tagged with {@link + * UsesBundleFinalizer} should be run for runners which support bundle finalization. + */ +public interface UsesBundleFinalizer {} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java index 34937fb1af383..c43840f2d972d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java @@ -599,7 +599,6 @@ public interface MultiOutputReceiver { * Finalize Bundles for further details. * */ - // TODO: Add support for bundle finalization parameter. @Documented @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.METHOD) @@ -1165,7 +1164,6 @@ public void populateDisplayData(DisplayData.Builder builder) {} * consumers without waiting for finalization to succeed. For pipelines that are sensitive to * duplicate messages, they must perform output deduplication in the pipeline. 
*/ - // TODO: Add support for a deduplication PTransform. @Experimental(Kind.SPLITTABLE_DO_FN) public interface BundleFinalizer { /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java index fb66a172c333c..ba6bc17a11b86 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java @@ -31,15 +31,20 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.state.TimerMap; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; +import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext; import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; import org.apache.beam.sdk.transforms.DoFn.OnTimerContext; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; +import org.apache.beam.sdk.transforms.DoFn.StartBundleContext; import org.apache.beam.sdk.transforms.Materializations.MultimapView; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker.BaseArgumentProvider; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; @@ -325,6 +330,12 @@ public Timer timer(String timerId) { public TimerMap timerFamily(String tagId) { throw new UnsupportedOperationException("DoFnTester doesn't support timerFamily yet"); } + + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + "DoFnTester doesn't support bundleFinalizer yet"); 
+ } }); } catch (UserCodeException e) { unwrapUserCodeException(e); @@ -459,38 +470,48 @@ public TupleTag getMainOutputTag() { return mainOutputTag; } - private class TestStartBundleContext extends DoFn.StartBundleContext { - - private TestStartBundleContext() { - fn.super(); + private class TestStartBundleContext extends BaseArgumentProvider { + @Override + public StartBundleContext startBundleContext(DoFn doFn) { + return fn.new StartBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return options; + } + }; } @Override - public PipelineOptions getPipelineOptions() { - return options; + public String getErrorContext() { + return "DoFnTester/StartBundle"; } } - private class TestFinishBundleContext extends DoFn.FinishBundleContext { - - private TestFinishBundleContext() { - fn.super(); - } - + private class TestFinishBundleContext extends BaseArgumentProvider { @Override - public PipelineOptions getPipelineOptions() { - return options; - } + public FinishBundleContext finishBundleContext(DoFn doFn) { + return fn.new FinishBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return options; + } - @Override - public void output(OutputT output, Instant timestamp, BoundedWindow window) { - output(mainOutputTag, output, timestamp, window); + @Override + public void output(OutputT output, Instant timestamp, BoundedWindow window) { + output(mainOutputTag, output, timestamp, window); + } + + @Override + public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { + getMutableOutput(tag) + .add(ValueInSingleWindow.of(output, timestamp, window, PaneInfo.NO_FIRING)); + } + }; } @Override - public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { - getMutableOutput(tag) - .add(ValueInSingleWindow.of(output, timestamp, window, PaneInfo.NO_FIRING)); + public String getErrorContext() { + return "DoFnTester/FinishBundle"; } } diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java index 5649ff6e1f2bb..503942c0e9c19 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java @@ -36,6 +36,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFn.ProcessElement; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.OnTimerMethod; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BundleFinalizerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.Cases; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ElementParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.FinishBundleContextParameter; @@ -107,6 +108,7 @@ class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory { public static final String ELEMENT_PARAMETER_METHOD = "element"; public static final String SCHEMA_ELEMENT_PARAMETER_METHOD = "schemaElement"; public static final String TIMESTAMP_PARAMETER_METHOD = "timestamp"; + public static final String BUNDLE_FINALIZER_PARAMETER_METHOD = "bundleFinalizer"; public static final String OUTPUT_ROW_RECEIVER_METHOD = "outputRowReceiver"; public static final String TIME_DOMAIN_PARAMETER_METHOD = "timeDomain"; public static final String OUTPUT_PARAMETER_METHOD = "outputReceiver"; @@ -366,9 +368,11 @@ public static double invokeGetSize( // public invokeStartBundle(Context c) { delegate.<@StartBundle>(c); } // ... etc ... 
.method(ElementMatchers.named("invokeStartBundle")) - .intercept(delegateOrNoop(clazzDescription, signature.startBundle())) + .intercept( + delegateMethodWithExtraParametersOrNoop(clazzDescription, signature.startBundle())) .method(ElementMatchers.named("invokeFinishBundle")) - .intercept(delegateOrNoop(clazzDescription, signature.finishBundle())) + .intercept( + delegateMethodWithExtraParametersOrNoop(clazzDescription, signature.finishBundle())) .method(ElementMatchers.named("invokeSetup")) .intercept(delegateOrNoop(clazzDescription, signature.setup())) .method(ElementMatchers.named("invokeTeardown")) @@ -779,6 +783,15 @@ public StackManipulation dispatch(TimestampParameter p) { TIMESTAMP_PARAMETER_METHOD, DoFn.class))); } + @Override + public StackManipulation dispatch(BundleFinalizerParameter p) { + return new StackManipulation.Compound( + pushDelegate, + MethodInvocation.invoke( + getExtraContextFactoryMethodDescription( + BUNDLE_FINALIZER_PARAMETER_METHOD, DoFn.class))); + } + @Override public StackManipulation dispatch(TimeDomainParameter p) { return new StackManipulation.Compound( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java index e4ac2d8ccd849..d52872e53714e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.state.TimerMap; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.DoFn.FinishBundle; import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; @@ -53,10 +54,10 @@ public interface DoFnInvoker { void invokeSetup(); /** Invoke the {@link 
DoFn.StartBundle} method on the bound {@link DoFn}. */ - void invokeStartBundle(DoFn.StartBundleContext c); + void invokeStartBundle(ArgumentProvider arguments); /** Invoke the {@link DoFn.FinishBundle} method on the bound {@link DoFn}. */ - void invokeFinishBundle(DoFn.FinishBundleContext c); + void invokeFinishBundle(ArgumentProvider arguments); /** Invoke the {@link DoFn.Teardown} method on the bound {@link DoFn}. */ void invokeTeardown(); @@ -168,9 +169,15 @@ interface ArgumentProvider { /** Provide a {@link OutputReceiver} for outputting rows to the default output. */ OutputReceiver outputRowReceiver(DoFn doFn); - /** Provide a {@link MultiOutputReceiver} for outputing to the default output. */ + /** Provide a {@link MultiOutputReceiver} for outputting to the default output. */ MultiOutputReceiver taggedOutputReceiver(DoFn doFn); + /** + * Provide a {@link BundleFinalizer} for being able to register a callback after the bundle has + * been successfully persisted by the runner. + */ + BundleFinalizer bundleFinalizer(); + /** * If this is a splittable {@link DoFn}, returns the associated restriction with the current * call. @@ -330,6 +337,12 @@ public Timer timer(String timerId) { String.format("RestrictionTracker unsupported in %s", getErrorContext())); } + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + String.format("BundleFinalizer unsupported in %s", getErrorContext())); + } + /** * Return a human readable representation of the current call context to be used during error * reporting. 
@@ -455,6 +468,11 @@ public String timerId(DoFn doFn) { return delegate.timerId(doFn); } + @Override + public BundleFinalizer bundleFinalizer() { + return delegate.bundleFinalizer(); + } + @Override public String getErrorContext() { return errorContext; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index b472cc31758fc..b675596c14350 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -271,6 +271,8 @@ public ResultT match(Cases cases) { return cases.dispatch((TimerFamilyParameter) this); } else if (this instanceof TimerIdParameter) { return cases.dispatch((TimerIdParameter) this); + } else if (this instanceof BundleFinalizerParameter) { + return cases.dispatch((BundleFinalizerParameter) this); } else { throw new IllegalStateException( String.format( @@ -321,6 +323,8 @@ public interface Cases { ResultT dispatch(TimerIdParameter p); + ResultT dispatch(BundleFinalizerParameter p); + /** A base class for a visitor with a default method for cases it is not interested in. 
*/ abstract class WithDefault implements Cases { @@ -401,6 +405,11 @@ public ResultT dispatch(RestrictionTrackerParameter p) { return dispatchDefault(p); } + @Override + public ResultT dispatch(BundleFinalizerParameter p) { + return dispatchDefault(p); + } + @Override public ResultT dispatch(StateParameter p) { return dispatchDefault(p); @@ -449,12 +458,29 @@ public ResultT dispatch(TimerFamilyParameter p) { new AutoValue_DoFnSignature_Parameter_TaggedOutputReceiverParameter(); private static final PipelineOptionsParameter PIPELINE_OPTIONS_PARAMETER = new AutoValue_DoFnSignature_Parameter_PipelineOptionsParameter(); + private static final BundleFinalizerParameter BUNDLE_FINALIZER_PARAMETER = + new AutoValue_DoFnSignature_Parameter_BundleFinalizerParameter(); /** Returns a {@link ProcessContextParameter}. */ public static ProcessContextParameter processContext() { return PROCESS_CONTEXT_PARAMETER; } + /** Returns a {@link StartBundleContextParameter}. */ + public static StartBundleContextParameter startBundleContext() { + return START_BUNDLE_CONTEXT_PARAMETER; + } + + /** Returns a {@link FinishBundleContextParameter}. */ + public static FinishBundleContextParameter finishBundleContext() { + return FINISH_BUNDLE_CONTEXT_PARAMETER; + } + + /** Returns a {@link BundleFinalizerParameter}. */ + public static BundleFinalizerParameter bundleFinalizer() { + return BUNDLE_FINALIZER_PARAMETER; + } + public static ElementParameter elementParameter(TypeDescriptor elementT) { return new AutoValue_DoFnSignature_Parameter_ElementParameter(elementT); } @@ -574,6 +600,16 @@ public abstract static class ProcessContextParameter extends Parameter { ProcessContextParameter() {} } + /** + * Descriptor for a {@link Parameter} of type {@link DoFn.BundleFinalizer}. + * + *

All such descriptors are equal. + */ + @AutoValue + public abstract static class BundleFinalizerParameter extends Parameter { + BundleFinalizerParameter() {} + } + /** * Descriptor for a {@link Parameter} of type {@link DoFn.Element}. * @@ -1041,13 +1077,23 @@ static TimerFamilyDeclaration create(String id, Field field) { /** Describes a {@link DoFn.StartBundle} or {@link DoFn.FinishBundle} method. */ @AutoValue - public abstract static class BundleMethod implements DoFnMethod { + public abstract static class BundleMethod implements MethodWithExtraParameters { /** The annotated method itself. */ @Override public abstract Method targetMethod(); - static BundleMethod create(Method targetMethod) { - return new AutoValue_DoFnSignature_BundleMethod(targetMethod); + /** Types of optional parameters of the annotated method, in the order they appear. */ + @Override + public abstract List extraParameters(); + + /** The type of window expected by this method, if any. */ + @Override + @Nullable + public abstract TypeDescriptor windowT(); + + static BundleMethod create(Method targetMethod, List extraParameters) { + /* start bundle/finish bundle currently do not get invoked on a per window basis and can't accept a BoundedWindow parameter */ + return new AutoValue_DoFnSignature_BundleMethod(targetMethod, extraParameters, null); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index 007663c587d5e..d34baba15ee5b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -112,7 +112,8 @@ private DoFnSignatures() {} Parameter.TimerParameter.class, Parameter.StateParameter.class, Parameter.SideInputParameter.class, - Parameter.TimerFamilyParameter.class); + Parameter.TimerFamilyParameter.class, + 
Parameter.BundleFinalizerParameter.class); private static final ImmutableList> ALLOWED_SPLITTABLE_PROCESS_ELEMENT_PARAMETERS = @@ -126,7 +127,16 @@ private DoFnSignatures() {} Parameter.TaggedOutputReceiverParameter.class, Parameter.ProcessContextParameter.class, Parameter.RestrictionTrackerParameter.class, - Parameter.SideInputParameter.class); + Parameter.SideInputParameter.class, + Parameter.BundleFinalizerParameter.class); + + private static final ImmutableList> ALLOWED_START_BUNDLE_PARAMETERS = + ImmutableList.of( + Parameter.StartBundleContextParameter.class, Parameter.BundleFinalizerParameter.class); + + private static final ImmutableList> ALLOWED_FINISH_BUNDLE_PARAMETERS = + ImmutableList.of( + Parameter.FinishBundleContextParameter.class, Parameter.BundleFinalizerParameter.class); private static final ImmutableList> ALLOWED_ON_TIMER_PARAMETERS = ImmutableList.of( @@ -550,14 +560,16 @@ private static DoFnSignature parseSignature(Class> fnClass) if (startBundleMethod != null) { ErrorReporter startBundleErrors = errors.forMethod(DoFn.StartBundle.class, startBundleMethod); signatureBuilder.setStartBundle( - analyzeStartBundleMethod(startBundleErrors, fnT, startBundleMethod, inputT, outputT)); + analyzeStartBundleMethod( + startBundleErrors, fnT, startBundleMethod, inputT, outputT, fnContext)); } if (finishBundleMethod != null) { ErrorReporter finishBundleErrors = errors.forMethod(DoFn.FinishBundle.class, finishBundleMethod); signatureBuilder.setFinishBundle( - analyzeFinishBundleMethod(finishBundleErrors, fnT, finishBundleMethod, inputT, outputT)); + analyzeFinishBundleMethod( + finishBundleErrors, fnT, finishBundleMethod, inputT, outputT, fnContext)); } if (setupMethod != null) { @@ -1085,6 +1097,8 @@ private static Parameter analyzeExtraParameter( TypeDescriptor outputT) { TypeDescriptor expectedProcessContextT = doFnProcessContextTypeOf(inputT, outputT); + TypeDescriptor expectedStartBundleContextT = doFnStartBundleContextTypeOf(inputT, outputT); + 
TypeDescriptor expectedFinishBundleContextT = doFnFinishBundleContextTypeOf(inputT, outputT); TypeDescriptor expectedOnTimerContextT = doFnOnTimerContextTypeOf(inputT, outputT); TypeDescriptor paramT = param.getType(); @@ -1115,12 +1129,26 @@ private static Parameter analyzeExtraParameter( return Parameter.sideInputParameter(paramT, sideInputId); } else if (rawType.equals(PaneInfo.class)) { return Parameter.paneInfoParameter(); + } else if (rawType.equals(DoFn.BundleFinalizer.class)) { + return Parameter.bundleFinalizer(); } else if (rawType.equals(DoFn.ProcessContext.class)) { paramErrors.checkArgument( paramT.equals(expectedProcessContextT), "ProcessContext argument must have type %s", format(expectedProcessContextT)); return Parameter.processContext(); + } else if (rawType.equals(DoFn.StartBundleContext.class)) { + paramErrors.checkArgument( + paramT.equals(expectedStartBundleContextT), + "StartBundleContext argument must have type %s", + format(expectedStartBundleContextT)); + return Parameter.startBundleContext(); + } else if (rawType.equals(DoFn.FinishBundleContext.class)) { + paramErrors.checkArgument( + paramT.equals(expectedFinishBundleContextT), + "FinishBundleContext argument must have type %s", + format(expectedFinishBundleContextT)); + return Parameter.finishBundleContext(); } else if (rawType.equals(DoFn.OnTimerContext.class)) { paramErrors.checkArgument( paramT.equals(expectedOnTimerContextT), @@ -1367,16 +1395,30 @@ static DoFnSignature.BundleMethod analyzeStartBundleMethod( TypeDescriptor> fnT, Method m, TypeDescriptor inputT, - TypeDescriptor outputT) { + TypeDescriptor outputT, + FnAnalysisContext fnContext) { errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); - TypeDescriptor expectedContextT = doFnStartBundleContextTypeOf(inputT, outputT); Type[] params = m.getGenericParameterTypes(); - errors.checkArgument( - params.length == 0 - || (params.length == 1 && fnT.resolveType(params[0]).equals(expectedContextT)), - "Must take a 
single argument of type %s", - format(expectedContextT)); - return DoFnSignature.BundleMethod.create(m); + MethodAnalysisContext methodContext = MethodAnalysisContext.create(); + for (int i = 0; i < params.length; ++i) { + Parameter extraParam = + analyzeExtraParameter( + errors, + fnContext, + methodContext, + fnT, + ParameterDescription.of( + m, i, fnT.resolveType(params[i]), Arrays.asList(m.getParameterAnnotations()[i])), + inputT, + outputT); + methodContext.addParameter(extraParam); + } + + for (Parameter parameter : methodContext.getExtraParameters()) { + checkParameterOneOf(errors, parameter, ALLOWED_START_BUNDLE_PARAMETERS); + } + + return DoFnSignature.BundleMethod.create(m, methodContext.extraParameters); } @VisibleForTesting @@ -1385,16 +1427,30 @@ static DoFnSignature.BundleMethod analyzeFinishBundleMethod( TypeDescriptor> fnT, Method m, TypeDescriptor inputT, - TypeDescriptor outputT) { + TypeDescriptor outputT, + FnAnalysisContext fnContext) { errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); - TypeDescriptor expectedContextT = doFnFinishBundleContextTypeOf(inputT, outputT); Type[] params = m.getGenericParameterTypes(); - errors.checkArgument( - params.length == 0 - || (params.length == 1 && fnT.resolveType(params[0]).equals(expectedContextT)), - "Must take a single argument of type %s", - format(expectedContextT)); - return DoFnSignature.BundleMethod.create(m); + MethodAnalysisContext methodContext = MethodAnalysisContext.create(); + for (int i = 0; i < params.length; ++i) { + Parameter extraParam = + analyzeExtraParameter( + errors, + fnContext, + methodContext, + fnT, + ParameterDescription.of( + m, i, fnT.resolveType(params[i]), Arrays.asList(m.getParameterAnnotations()[i])), + inputT, + outputT); + methodContext.addParameter(extraParam); + } + + for (Parameter parameter : methodContext.getExtraParameters()) { + checkParameterOneOf(errors, parameter, ALLOWED_FINISH_BUNDLE_PARAMETERS); + } + + return 
DoFnSignature.BundleMethod.create(m, methodContext.extraParameters); } private static DoFnSignature.LifecycleMethod analyzeLifecycleMethod( diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java index e22f72920deba..14763f22fd576 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.transforms; +import static java.lang.Thread.sleep; import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume; import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; @@ -28,6 +29,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.KvCoder; @@ -39,6 +41,7 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.testing.UsesBoundedSplittableParDo; +import org.apache.beam.sdk.testing.UsesBundleFinalizer; import org.apache.beam.sdk.testing.UsesParDoLifecycle; import org.apache.beam.sdk.testing.UsesSideInputs; import org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs; @@ -47,7 +50,9 @@ import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.DoFn.BoundedPerElement; import org.apache.beam.sdk.transforms.DoFn.UnboundedPerElement; +import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; +import 
org.apache.beam.sdk.transforms.splittabledofn.SplitResult; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; import org.apache.beam.sdk.transforms.windowing.Never; @@ -807,6 +812,80 @@ public OffsetRange getInitialRestriction() { } } + /** + * While the finalization callback hasn't been invoked, this DoFn will keep requesting + * finalization, wait one second and then checkpoint, up to MAX_ATTEMPTS times. Once the + * callback has been invoked, the DoFn will output the element and stop. + */ + public static class BundleFinalizingSplittableDoFn extends DoFn { + private static final long MAX_ATTEMPTS = 300; + private static final AtomicBoolean wasFinalized = new AtomicBoolean(); + + @NewTracker + public RestrictionTracker newTracker(@Restriction OffsetRange restriction) { + // Use a modified OffsetRangeTracker with disabled splitting to prevent + // parallelization of execution. + return new OffsetRangeTracker(restriction) { + @Override + public SplitResult trySplit(double fractionOfRemainder) { + return null; + } + }; + } + + @ProcessElement + public ProcessContinuation process( + @Element String element, + OutputReceiver receiver, + RestrictionTracker tracker, + BundleFinalizer bundleFinalizer) + throws InterruptedException { + if (wasFinalized.get()) { + // Claim beyond the end now that we know we have been finalized. + tracker.tryClaim(Long.MAX_VALUE); + receiver.output(element); + return stop(); + } + if (tracker.tryClaim(tracker.currentRestriction().getFrom() + 1)) { + bundleFinalizer.afterBundleCommit( + Instant.now().plus(Duration.standardSeconds(MAX_ATTEMPTS)), + () -> wasFinalized.set(true)); + // We sleep here instead of setting a resume time since the resume time doesn't need to + // be honored. 
+ sleep(1000L); // 1 second + return resume(); + } + return stop(); + } + + @GetInitialRestriction + public OffsetRange getInitialRestriction() { + return new OffsetRange(0, MAX_ATTEMPTS); + } + } + + @Test + @Category({ValidatesRunner.class, UsesBoundedSplittableParDo.class, UsesBundleFinalizer.class}) + public void testBundleFinalizationOccursOnBoundedSplittableDoFn() throws Exception { + @BoundedPerElement + class BoundedBundleFinalizingSplittableDoFn extends BundleFinalizingSplittableDoFn {} + Pipeline p = TestPipeline.create(); + PCollection foo = p.apply(Create.of("foo")); + PCollection res = foo.apply(ParDo.of(new BoundedBundleFinalizingSplittableDoFn())); + PAssert.that(res).containsInAnyOrder("foo"); + } + + @Test + @Category({ValidatesRunner.class, UsesUnboundedSplittableParDo.class, UsesBundleFinalizer.class}) + public void testBundleFinalizationOccursOnUnboundedSplittableDoFn() throws Exception { + @UnboundedPerElement + class UnboundedBundleFinalizingSplittableDoFn extends BundleFinalizingSplittableDoFn {} + Pipeline p = TestPipeline.create(); + PCollection foo = p.apply(Create.of("foo")); + PCollection res = foo.apply(ParDo.of(new UnboundedBundleFinalizingSplittableDoFn())); + PAssert.that(res).containsInAnyOrder("foo"); + } + // TODO (https://issues.apache.org/jira/browse/BEAM-988): Test that Splittable DoFn // emits output immediately (i.e. has a pass-through trigger) regardless of input's // windowing/triggering strategy. 
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java index 8de6785af950b..902d982b99af6 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java @@ -26,6 +26,7 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; @@ -70,7 +71,6 @@ import org.mockito.AdditionalAnswers; import org.mockito.Matchers; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; /** Tests for {@link DoFnInvokers}. */ @@ -342,6 +342,10 @@ public SomeRestrictionTracker newTracker(@Restriction SomeRestriction restrictio @Test public void testDoFnWithStartBundleSetupTeardown() throws Exception { + when(mockArgumentProvider.startBundleContext(any(DoFn.class))) + .thenReturn(mockStartBundleContext); + when(mockArgumentProvider.finishBundleContext(any(DoFn.class))) + .thenReturn(mockFinishBundleContext); class MockFn extends DoFn { @ProcessElement public void processElement(ProcessContext c) {} @@ -362,8 +366,8 @@ public void after() {} MockFn fn = mock(MockFn.class); DoFnInvoker invoker = DoFnInvokers.invokerFor(fn); invoker.invokeSetup(); - invoker.invokeStartBundle(mockStartBundleContext); - invoker.invokeFinishBundle(mockFinishBundleContext); + invoker.invokeStartBundle(mockArgumentProvider); + invoker.invokeFinishBundle(mockArgumentProvider); invoker.invokeTeardown(); verify(fn).before(); verify(fn).startBundle(mockStartBundleContext); @@ -450,7 +454,7 @@ public void splitRestriction( } })) .when(fn) - .splitRestriction(eq(mockElement), 
same(restriction), Mockito.any()); + .splitRestriction(eq(mockElement), same(restriction), any()); when(fn.newTracker(restriction)).thenReturn(tracker); when(fn.processElement(mockProcessContext, tracker)).thenReturn(resume()); @@ -832,6 +836,9 @@ public DoFn.ProcessContext processContext(DoFn doFn) { @Test public void testStartBundleException() throws Exception { + DoFnInvoker.ArgumentProvider mockArguments = + mock(DoFnInvoker.ArgumentProvider.class); + when(mockArguments.startBundleContext(any(DoFn.class))).thenReturn(null); DoFnInvoker invoker = DoFnInvokers.invokerFor( new DoFn() { @@ -845,11 +852,14 @@ public void processElement(@SuppressWarnings("unused") ProcessContext c) {} }); thrown.expect(UserCodeException.class); thrown.expectMessage("bogus"); - invoker.invokeStartBundle(null); + invoker.invokeStartBundle(mockArguments); } @Test public void testFinishBundleException() throws Exception { + DoFnInvoker.ArgumentProvider mockArguments = + mock(DoFnInvoker.ArgumentProvider.class); + when(mockArguments.finishBundleContext(any(DoFn.class))).thenReturn(null); DoFnInvoker invoker = DoFnInvokers.invokerFor( new DoFn() { @@ -863,7 +873,7 @@ public void processElement(@SuppressWarnings("unused") ProcessContext c) {} }); thrown.expect(UserCodeException.class); thrown.expectMessage("bogus"); - invoker.invokeFinishBundle(null); + invoker.invokeFinishBundle(mockArguments); } @Test diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java index dc0716d5f951d..24e581b6893dd 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java @@ -55,19 +55,23 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Sum; import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BundleFinalizerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ElementParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.FinishBundleContextParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.OutputReceiverParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.PaneInfoParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.PipelineOptionsParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ProcessContextParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SideInputParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StartBundleContextParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TaggedOutputReceiverParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimeDomainParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimestampParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures.FnAnalysisContext; import org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.FakeDoFn; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -121,10 +125,11 @@ public void process( OutputReceiver receiver, PipelineOptions options, @SideInput("tag1") String input1, - @SideInput("tag2") 
Integer input2) {} + @SideInput("tag2") Integer input2, + BundleFinalizer bundleFinalizer) {} }.getClass()); - assertThat(sig.processElement().extraParameters().size(), equalTo(8)); + assertThat(sig.processElement().extraParameters().size(), equalTo(9)); assertThat(sig.processElement().extraParameters().get(0), instanceOf(ElementParameter.class)); assertThat(sig.processElement().extraParameters().get(1), instanceOf(TimestampParameter.class)); assertThat(sig.processElement().extraParameters().get(2), instanceOf(WindowParameter.class)); @@ -135,6 +140,8 @@ public void process( sig.processElement().extraParameters().get(5), instanceOf(PipelineOptionsParameter.class)); assertThat(sig.processElement().extraParameters().get(6), instanceOf(SideInputParameter.class)); assertThat(sig.processElement().extraParameters().get(7), instanceOf(SideInputParameter.class)); + assertThat( + sig.processElement().extraParameters().get(8), instanceOf(BundleFinalizerParameter.class)); } @Test @@ -276,8 +283,7 @@ public void process(ProcessContext c) {} @Test public void testBadExtraContext() throws Exception { thrown.expect(IllegalArgumentException.class); - thrown.expectMessage( - "Must take a single argument of type DoFn.StartBundleContext"); + thrown.expectMessage("int is not a valid context parameter"); DoFnSignatures.analyzeStartBundleMethod( errors(), @@ -286,7 +292,8 @@ public void testBadExtraContext() throws Exception { void method(DoFn.StartBundleContext c, int n) {} }.getMethod(), TypeDescriptor.of(Integer.class), - TypeDescriptor.of(String.class)); + TypeDescriptor.of(String.class), + FnAnalysisContext.create()); } @Test @@ -345,6 +352,25 @@ void startBundle() {} }.getClass()); } + @Test + public void testStartBundleWithAllParameters() throws Exception { + DoFnSignature sig = + DoFnSignatures.getSignature( + new DoFn() { + @ProcessElement + public void processElement() {} + + @StartBundle + public void startBundle( + StartBundleContext context, BundleFinalizer bundleFinalizer) 
{} + }.getClass()); + assertThat(sig.startBundle().extraParameters().size(), equalTo(2)); + assertThat( + sig.startBundle().extraParameters().get(0), instanceOf(StartBundleContextParameter.class)); + assertThat( + sig.startBundle().extraParameters().get(1), instanceOf(BundleFinalizerParameter.class)); + } + @Test public void testPrivateFinishBundle() throws Exception { thrown.expect(IllegalArgumentException.class); @@ -361,6 +387,26 @@ void finishBundle() {} }.getClass()); } + @Test + public void testFinishBundleWithAllParameters() throws Exception { + DoFnSignature sig = + DoFnSignatures.getSignature( + new DoFn() { + @ProcessElement + public void processElement() {} + + @FinishBundle + public void finishBundle( + FinishBundleContext context, BundleFinalizer bundleFinalizer) {} + }.getClass()); + assertThat(sig.finishBundle().extraParameters().size(), equalTo(2)); + assertThat( + sig.finishBundle().extraParameters().get(0), + instanceOf(FinishBundleContextParameter.class)); + assertThat( + sig.finishBundle().extraParameters().get(1), instanceOf(BundleFinalizerParameter.class)); + } + @Test public void testTimerIdWithWrongType() throws Exception { thrown.expect(IllegalArgumentException.class); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java index e1756dc5a52dd..650d86b648e24 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java @@ -50,6 +50,7 @@ import org.apache.beam.sdk.fn.data.RemoteGrpcPortRead; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; 
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Ints; @@ -95,7 +96,8 @@ public BeamFnDataReadRunner createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { FnDataReceiver> consumer = diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java index 8cb9d33b7529b..86deadb3286bf 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java @@ -45,6 +45,7 @@ import org.apache.beam.sdk.fn.data.RemoteGrpcPortWrite; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; import org.slf4j.Logger; @@ -89,7 +90,8 @@ public BeamFnDataWriteRunner createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { BeamFnDataWriteRunner runner = diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java index a06c002ac6dbc..4429bb81552c1 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java +++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java @@ -41,6 +41,7 @@ import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.Source.Reader; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException; @@ -83,7 +84,8 @@ public BoundedSourceRunner createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) { + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) { ImmutableList.Builder>> consumers = ImmutableList.builder(); for (String pCollectionId : pTransform.getOutputsMap().values()) { consumers.add(pCollectionConsumerRegistry.getMultiplexingConsumer(pCollectionId)); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java index 6bd16bed86971..feb6b6b76faba 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java @@ -40,6 +40,7 @@ import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; @@ -127,7 +128,8 @@ public PrecombineRunner createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer 
tearDownFunctions, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { // Get objects needed to create the runner. RehydratedComponents rehydratedComponents = diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java index 702751c6bbc6c..2307b5b2e7fbe 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java @@ -36,6 +36,7 @@ import org.apache.beam.sdk.fn.data.FnDataReceiver; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; @@ -68,7 +69,8 @@ public FlattenRunner createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { // Give each input a MultiplexingFnDataReceiver to all outputs of the flatten. 
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index 907ec74b4ff3d..3ca5d4c482094 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -64,6 +64,7 @@ import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.state.TimerMap; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; import org.apache.beam.sdk.transforms.DoFnOutputReceivers; @@ -71,6 +72,7 @@ import org.apache.beam.sdk.transforms.Materializations; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker.BaseArgumentProvider; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker.DelegatingArgumentProvider; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; @@ -127,169 +129,6 @@ public Map getPTransformRunnerFactories() { } } - static class Context { - final PipelineOptions pipelineOptions; - final BeamFnStateClient beamFnStateClient; - final String ptransformId; - final PTransform pTransform; - final Supplier processBundleInstructionId; - final RehydratedComponents rehydratedComponents; - final DoFn doFn; - final DoFnSignature doFnSignature; - final TupleTag mainOutputTag; - final Coder inputCoder; - final SchemaCoder schemaCoder; - final Coder keyCoder; - final SchemaCoder mainOutputSchemaCoder; - final Coder windowCoder; - final WindowingStrategy windowingStrategy; - final Map, SideInputSpec> tagToSideInputSpecMap; - Map, Coder> outputCoders; - final ParDoPayload parDoPayload; 
- final ListMultimap>> localNameToConsumer; - final BundleSplitListener splitListener; - - Context( - PipelineOptions pipelineOptions, - BeamFnStateClient beamFnStateClient, - String ptransformId, - PTransform pTransform, - Supplier processBundleInstructionId, - Map pCollections, - Map coders, - Map windowingStrategies, - PCollectionConsumerRegistry pCollectionConsumerRegistry, - BundleSplitListener splitListener) { - this.pipelineOptions = pipelineOptions; - this.beamFnStateClient = beamFnStateClient; - this.ptransformId = ptransformId; - this.pTransform = pTransform; - this.processBundleInstructionId = processBundleInstructionId; - ImmutableMap.Builder, SideInputSpec> tagToSideInputSpecMapBuilder = - ImmutableMap.builder(); - try { - rehydratedComponents = - RehydratedComponents.forComponents( - RunnerApi.Components.newBuilder() - .putAllCoders(coders) - .putAllPcollections(pCollections) - .putAllWindowingStrategies(windowingStrategies) - .build()) - .withPipeline(Pipeline.create()); - parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload()); - doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload); - doFnSignature = DoFnSignatures.signatureForDoFn(doFn); - switch (pTransform.getSpec().getUrn()) { - case PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN: - case PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN: - case PTransformTranslation.PAR_DO_TRANSFORM_URN: - mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload); - break; - case PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN: - case PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN: - case PTransformTranslation.SPLITTABLE_SPLIT_RESTRICTION_URN: - mainOutputTag = - new TupleTag(Iterables.getOnlyElement(pTransform.getOutputsMap().keySet())); - break; - default: - throw new IllegalStateException( - String.format("Unknown urn: %s", pTransform.getSpec().getUrn())); - } - String mainInputTag = - Iterables.getOnlyElement( - 
Sets.difference( - pTransform.getInputsMap().keySet(), - Sets.union( - parDoPayload.getSideInputsMap().keySet(), - parDoPayload.getTimerSpecsMap().keySet()))); - PCollection mainInput = pCollections.get(pTransform.getInputsOrThrow(mainInputTag)); - inputCoder = rehydratedComponents.getCoder(mainInput.getCoderId()); - if (inputCoder instanceof KvCoder - // TODO: Stop passing windowed value coders within PCollections. - || (inputCoder instanceof WindowedValue.WindowedValueCoder - && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof KvCoder))) { - this.keyCoder = - inputCoder instanceof WindowedValueCoder - ? ((KvCoder) ((WindowedValueCoder) inputCoder).getValueCoder()).getKeyCoder() - : ((KvCoder) inputCoder).getKeyCoder(); - } else { - this.keyCoder = null; - } - if (inputCoder instanceof SchemaCoder - // TODO: Stop passing windowed value coders within PCollections. - || (inputCoder instanceof WindowedValue.WindowedValueCoder - && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof SchemaCoder))) { - this.schemaCoder = - inputCoder instanceof WindowedValueCoder - ? 
(SchemaCoder) ((WindowedValueCoder) inputCoder).getValueCoder() - : ((SchemaCoder) inputCoder); - } else { - this.schemaCoder = null; - } - - windowingStrategy = - (WindowingStrategy) - rehydratedComponents.getWindowingStrategy(mainInput.getWindowingStrategyId()); - windowCoder = windowingStrategy.getWindowFn().windowCoder(); - - outputCoders = Maps.newHashMap(); - for (Map.Entry entry : pTransform.getOutputsMap().entrySet()) { - TupleTag outputTag = new TupleTag<>(entry.getKey()); - RunnerApi.PCollection outputPCollection = pCollections.get(entry.getValue()); - Coder outputCoder = rehydratedComponents.getCoder(outputPCollection.getCoderId()); - if (outputCoder instanceof WindowedValueCoder) { - outputCoder = ((WindowedValueCoder) outputCoder).getValueCoder(); - } - outputCoders.put(outputTag, outputCoder); - } - Coder outputCoder = (Coder) outputCoders.get(mainOutputTag); - mainOutputSchemaCoder = - (outputCoder instanceof SchemaCoder) ? (SchemaCoder) outputCoder : null; - - // Build the map from tag id to side input specification - for (Map.Entry entry : - parDoPayload.getSideInputsMap().entrySet()) { - String sideInputTag = entry.getKey(); - RunnerApi.SideInput sideInput = entry.getValue(); - checkArgument( - Materializations.MULTIMAP_MATERIALIZATION_URN.equals( - sideInput.getAccessPattern().getUrn()), - "This SDK is only capable of dealing with %s materializations " - + "but was asked to handle %s for PCollectionView with tag %s.", - Materializations.MULTIMAP_MATERIALIZATION_URN, - sideInput.getAccessPattern().getUrn(), - sideInputTag); - - PCollection sideInputPCollection = - pCollections.get(pTransform.getInputsOrThrow(sideInputTag)); - WindowingStrategy sideInputWindowingStrategy = - rehydratedComponents.getWindowingStrategy( - sideInputPCollection.getWindowingStrategyId()); - tagToSideInputSpecMapBuilder.put( - new TupleTag<>(entry.getKey()), - SideInputSpec.create( - rehydratedComponents.getCoder(sideInputPCollection.getCoderId()), - 
sideInputWindowingStrategy.getWindowFn().windowCoder(), - PCollectionViewTranslation.viewFnFromProto(entry.getValue().getViewFn()), - PCollectionViewTranslation.windowMappingFnFromProto( - entry.getValue().getWindowMappingFn()))); - } - } catch (IOException exn) { - throw new IllegalArgumentException("Malformed ParDoPayload", exn); - } - - ImmutableListMultimap.Builder>> - localNameToConsumerBuilder = ImmutableListMultimap.builder(); - for (Map.Entry entry : pTransform.getOutputsMap().entrySet()) { - localNameToConsumerBuilder.putAll( - entry.getKey(), pCollectionConsumerRegistry.getMultiplexingConsumer(entry.getValue())); - } - localNameToConsumer = localNameToConsumerBuilder.build(); - tagToSideInputSpecMap = tagToSideInputSpecMapBuilder.build(); - this.splitListener = splitListener; - } - } - static class Factory implements PTransformRunnerFactory< FnApiDoFnRunner> { @@ -310,9 +149,11 @@ static class Factory PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) { - Context context = - new Context<>( + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) { + + FnApiDoFnRunner runner = + new FnApiDoFnRunner<>( pipelineOptions, beamFnStateClient, pTransformId, @@ -322,10 +163,8 @@ static class Factory coders, windowingStrategies, pCollectionConsumerRegistry, - splitListener); - - FnApiDoFnRunner runner = - new FnApiDoFnRunner<>(context); + splitListener, + bundleFinalizer); // Register the appropriate handlers. startFunctionRegistry.register(pTransformId, runner::startBundle); @@ -360,10 +199,10 @@ static class Factory pTransform.getInputsOrThrow(mainInput), pTransformId, (FnDataReceiver) mainInputConsumer); // Register as a consumer for each timer PCollection. 
- for (String localName : context.parDoPayload.getTimerSpecsMap().keySet()) { + for (String localName : runner.parDoPayload.getTimerSpecsMap().keySet()) { TimeDomain timeDomain = DoFnSignatures.getTimerSpecOrThrow( - context.doFnSignature.timerDeclarations().get(localName), context.doFn) + runner.doFnSignature.timerDeclarations().get(localName), runner.doFn) .getTimeDomain(); pCollectionConsumerRegistry.register( pTransform.getInputsOrThrow(localName), @@ -382,15 +221,36 @@ static class Factory ////////////////////////////////////////////////////////////////////////////////////////////////// - private final Context context; + private final PipelineOptions pipelineOptions; + private final BeamFnStateClient beamFnStateClient; + private final String pTransformId; + private final PTransform pTransform; + private final Supplier processBundleInstructionId; + private final RehydratedComponents rehydratedComponents; + private final DoFn doFn; + private final DoFnSignature doFnSignature; + private final TupleTag mainOutputTag; + private final Coder inputCoder; + private final SchemaCoder schemaCoder; + private final Coder keyCoder; + private final SchemaCoder mainOutputSchemaCoder; + private final Coder windowCoder; + private final WindowingStrategy windowingStrategy; + private final Map, SideInputSpec> tagToSideInputSpecMap; + private final Map, Coder> outputCoders; + private final ParDoPayload parDoPayload; + private final ListMultimap>> localNameToConsumer; + private final BundleSplitListener splitListener; + private final BundleFinalizer bundleFinalizer; + private final Collection>> mainOutputConsumers; private final String mainInputId; private FnApiStateAccessor stateAccessor; private final DoFnInvoker doFnInvoker; - private final DoFn.StartBundleContext startBundleContext; + private final StartBundleArgumentProvider startBundleArgumentProvider; private final ProcessBundleContext processContext; private final OnTimerContext onTimerContext; - private final 
DoFn.FinishBundleContext finishBundleContext; + private final FinishBundleArgumentProvider finishBundleArgumentProvider; /** * Only set for {@link PTransformTranslation#SPLITTABLE_PROCESS_ELEMENTS_URN} and {@link @@ -427,29 +287,162 @@ static class Factory /** Only valid during {@link #processTimer}, null otherwise. */ private TimeDomain currentTimeDomain; - FnApiDoFnRunner(Context context) { - this.context = context; + FnApiDoFnRunner( + PipelineOptions pipelineOptions, + BeamFnStateClient beamFnStateClient, + String pTransformId, + PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Map windowingStrategies, + PCollectionConsumerRegistry pCollectionConsumerRegistry, + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) { + this.pipelineOptions = pipelineOptions; + this.beamFnStateClient = beamFnStateClient; + this.pTransformId = pTransformId; + this.pTransform = pTransform; + this.processBundleInstructionId = processBundleInstructionId; + ImmutableMap.Builder, SideInputSpec> tagToSideInputSpecMapBuilder = + ImmutableMap.builder(); try { - this.mainInputId = ParDoTranslation.getMainInputName(context.pTransform); + rehydratedComponents = + RehydratedComponents.forComponents( + RunnerApi.Components.newBuilder() + .putAllCoders(coders) + .putAllPcollections(pCollections) + .putAllWindowingStrategies(windowingStrategies) + .build()) + .withPipeline(Pipeline.create()); + parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload()); + doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload); + doFnSignature = DoFnSignatures.signatureForDoFn(doFn); + switch (pTransform.getSpec().getUrn()) { + case PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN: + case PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN: + case PTransformTranslation.PAR_DO_TRANSFORM_URN: + mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload); + break; + case 
PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN: + case PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN: + case PTransformTranslation.SPLITTABLE_SPLIT_RESTRICTION_URN: + mainOutputTag = + new TupleTag(Iterables.getOnlyElement(pTransform.getOutputsMap().keySet())); + break; + default: + throw new IllegalStateException( + String.format("Unknown urn: %s", pTransform.getSpec().getUrn())); + } + String mainInputTag = + Iterables.getOnlyElement( + Sets.difference( + pTransform.getInputsMap().keySet(), + Sets.union( + parDoPayload.getSideInputsMap().keySet(), + parDoPayload.getTimerSpecsMap().keySet()))); + PCollection mainInput = pCollections.get(pTransform.getInputsOrThrow(mainInputTag)); + inputCoder = rehydratedComponents.getCoder(mainInput.getCoderId()); + if (inputCoder instanceof KvCoder + // TODO: Stop passing windowed value coders within PCollections. + || (inputCoder instanceof WindowedValue.WindowedValueCoder + && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof KvCoder))) { + this.keyCoder = + inputCoder instanceof WindowedValueCoder + ? ((KvCoder) ((WindowedValueCoder) inputCoder).getValueCoder()).getKeyCoder() + : ((KvCoder) inputCoder).getKeyCoder(); + } else { + this.keyCoder = null; + } + if (inputCoder instanceof SchemaCoder + // TODO: Stop passing windowed value coders within PCollections. + || (inputCoder instanceof WindowedValue.WindowedValueCoder + && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof SchemaCoder))) { + this.schemaCoder = + inputCoder instanceof WindowedValueCoder + ? 
(SchemaCoder) ((WindowedValueCoder) inputCoder).getValueCoder() + : ((SchemaCoder) inputCoder); + } else { + this.schemaCoder = null; + } + + windowingStrategy = + (WindowingStrategy) + rehydratedComponents.getWindowingStrategy(mainInput.getWindowingStrategyId()); + windowCoder = windowingStrategy.getWindowFn().windowCoder(); + + outputCoders = Maps.newHashMap(); + for (Map.Entry entry : pTransform.getOutputsMap().entrySet()) { + TupleTag outputTag = new TupleTag<>(entry.getKey()); + RunnerApi.PCollection outputPCollection = pCollections.get(entry.getValue()); + Coder outputCoder = rehydratedComponents.getCoder(outputPCollection.getCoderId()); + if (outputCoder instanceof WindowedValueCoder) { + outputCoder = ((WindowedValueCoder) outputCoder).getValueCoder(); + } + outputCoders.put(outputTag, outputCoder); + } + Coder outputCoder = (Coder) outputCoders.get(mainOutputTag); + mainOutputSchemaCoder = + (outputCoder instanceof SchemaCoder) ? (SchemaCoder) outputCoder : null; + + // Build the map from tag id to side input specification + for (Map.Entry entry : + parDoPayload.getSideInputsMap().entrySet()) { + String sideInputTag = entry.getKey(); + RunnerApi.SideInput sideInput = entry.getValue(); + checkArgument( + Materializations.MULTIMAP_MATERIALIZATION_URN.equals( + sideInput.getAccessPattern().getUrn()), + "This SDK is only capable of dealing with %s materializations " + + "but was asked to handle %s for PCollectionView with tag %s.", + Materializations.MULTIMAP_MATERIALIZATION_URN, + sideInput.getAccessPattern().getUrn(), + sideInputTag); + + PCollection sideInputPCollection = + pCollections.get(pTransform.getInputsOrThrow(sideInputTag)); + WindowingStrategy sideInputWindowingStrategy = + rehydratedComponents.getWindowingStrategy( + sideInputPCollection.getWindowingStrategyId()); + tagToSideInputSpecMapBuilder.put( + new TupleTag<>(entry.getKey()), + SideInputSpec.create( + rehydratedComponents.getCoder(sideInputPCollection.getCoderId()), + 
sideInputWindowingStrategy.getWindowFn().windowCoder(), + PCollectionViewTranslation.viewFnFromProto(entry.getValue().getViewFn()), + PCollectionViewTranslation.windowMappingFnFromProto( + entry.getValue().getWindowMappingFn()))); + } + } catch (IOException exn) { + throw new IllegalArgumentException("Malformed ParDoPayload", exn); + } + + ImmutableListMultimap.Builder>> + localNameToConsumerBuilder = ImmutableListMultimap.builder(); + for (Map.Entry entry : pTransform.getOutputsMap().entrySet()) { + localNameToConsumerBuilder.putAll( + entry.getKey(), pCollectionConsumerRegistry.getMultiplexingConsumer(entry.getValue())); + } + localNameToConsumer = localNameToConsumerBuilder.build(); + tagToSideInputSpecMap = tagToSideInputSpecMapBuilder.build(); + this.splitListener = splitListener; + this.bundleFinalizer = bundleFinalizer; + + try { + this.mainInputId = ParDoTranslation.getMainInputName(pTransform); } catch (IOException e) { throw new RuntimeException(e); } this.mainOutputConsumers = (Collection>>) - (Collection) context.localNameToConsumer.get(context.mainOutputTag.getId()); - this.doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.parDoPayload); - this.sideInputMapping = ParDoTranslation.getSideInputMapping(context.parDoPayload); - this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn); + (Collection) localNameToConsumer.get(mainOutputTag.getId()); + this.doFnSchemaInformation = ParDoTranslation.getSchemaInformation(parDoPayload); + this.sideInputMapping = ParDoTranslation.getSideInputMapping(parDoPayload); + this.doFnInvoker = DoFnInvokers.invokerFor(doFn); this.doFnInvoker.invokeSetup(); - this.startBundleContext = - this.context.doFn.new StartBundleContext() { - @Override - public PipelineOptions getPipelineOptions() { - return context.pipelineOptions; - } - }; - switch (context.pTransform.getSpec().getUrn()) { + this.startBundleArgumentProvider = new StartBundleArgumentProvider(); + switch (pTransform.getSpec().getUrn()) { case 
PTransformTranslation.SPLITTABLE_SPLIT_RESTRICTION_URN: // OutputT == RestrictionT this.processContext = @@ -514,36 +507,12 @@ public Instant timestamp(DoFn doFn) { break; default: throw new IllegalStateException( - String.format("Unknown URN %s", context.pTransform.getSpec().getUrn())); + String.format("Unknown URN %s", pTransform.getSpec().getUrn())); } this.onTimerContext = new OnTimerContext(); - this.finishBundleContext = - this.context.doFn.new FinishBundleContext() { - @Override - public PipelineOptions getPipelineOptions() { - return context.pipelineOptions; - } - - @Override - public void output(OutputT output, Instant timestamp, BoundedWindow window) { - outputTo( - mainOutputConsumers, - WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); - } - - @Override - public void output( - TupleTag tag, T output, Instant timestamp, BoundedWindow window) { - Collection>> consumers = - (Collection) context.localNameToConsumer.get(tag.getId()); - if (consumers == null) { - throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); - } - outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); - } - }; - - switch (context.pTransform.getSpec().getUrn()) { + this.finishBundleArgumentProvider = new FinishBundleArgumentProvider(); + + switch (pTransform.getSpec().getUrn()) { case PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN: this.convertSplitResultToWindowedSplitResult = (splitResult) -> @@ -606,7 +575,7 @@ public Object restriction() { throw new IllegalStateException( String.format( "Unimplemented split conversion handler for %s.", - context.pTransform.getSpec().getUrn())); + pTransform.getSpec().getUrn())); }; } } @@ -614,17 +583,17 @@ public Object restriction() { public void startBundle() { this.stateAccessor = new FnApiStateAccessor( - context.pipelineOptions, - context.ptransformId, - context.processBundleInstructionId, - context.tagToSideInputSpecMap, - context.beamFnStateClient, - 
context.keyCoder, - (Coder) context.windowCoder, + pipelineOptions, + pTransformId, + processBundleInstructionId, + tagToSideInputSpecMap, + beamFnStateClient, + keyCoder, + (Coder) windowCoder, () -> MoreObjects.firstNonNull(currentElement, currentTimer), () -> currentWindow); - doFnInvoker.invokeStartBundle(startBundleContext); + doFnInvoker.invokeStartBundle(startBundleArgumentProvider); } public void processElementForParDo(WindowedValue elem) { @@ -729,8 +698,7 @@ public void processElementForElementAndRestriction(WindowedValue>>> consumers = - (Collection) context.localNameToConsumer.get(timerId); + (Collection) localNameToConsumer.get(timerId); if (currentOutputTimestamp == null) { if (TimeDomain.EVENT_TIME.equals(timeDomain)) { @@ -965,6 +931,83 @@ public org.apache.beam.sdk.state.Timer get(String timerId) { } } + private class StartBundleArgumentProvider extends BaseArgumentProvider { + private class Context extends DoFn.StartBundleContext { + Context() { + doFn.super(); + } + + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } + } + + private final Context context = new Context(); + + @Override + public DoFn.StartBundleContext startBundleContext(DoFn doFn) { + return context; + } + + @Override + public BundleFinalizer bundleFinalizer() { + return bundleFinalizer; + } + + @Override + public String getErrorContext() { + return "FnApiDoFnRunner/StartBundle"; + } + } + + private class FinishBundleArgumentProvider extends BaseArgumentProvider { + private class Context extends DoFn.FinishBundleContext { + Context() { + doFn.super(); + } + + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } + + @Override + public void output(OutputT output, Instant timestamp, BoundedWindow window) { + outputTo( + mainOutputConsumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); + } + + @Override + public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { + 
Collection>> consumers = + (Collection) localNameToConsumer.get(tag.getId()); + if (consumers == null) { + throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); + } + outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); + } + } + + private final Context context = new Context(); + + @Override + public DoFn.FinishBundleContext finishBundleContext( + DoFn doFn) { + return context; + } + + @Override + public BundleFinalizer bundleFinalizer() { + return bundleFinalizer; + } + + @Override + public String getErrorContext() { + return "FnApiDoFnRunner/FinishBundle"; + } + } + /** * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.ProcessElement @ProcessElement}. */ @@ -972,7 +1015,7 @@ private class ProcessBundleContext extends DoFn.ProcessContext implements DoFnInvoker.ArgumentProvider { private ProcessBundleContext() { - context.doFn.super(); + doFn.super(); } @Override @@ -1043,12 +1086,17 @@ public OutputReceiver outputReceiver(DoFn doFn) { @Override public OutputReceiver outputRowReceiver(DoFn doFn) { - return DoFnOutputReceivers.rowReceiver(this, null, context.mainOutputSchemaCoder); + return DoFnOutputReceivers.rowReceiver(this, null, mainOutputSchemaCoder); } @Override public MultiOutputReceiver taggedOutputReceiver(DoFn doFn) { - return DoFnOutputReceivers.windowedMultiReceiver(this, context.outputCoders); + return DoFnOutputReceivers.windowedMultiReceiver(this, outputCoders); + } + + @Override + public BundleFinalizer bundleFinalizer() { + return bundleFinalizer; } @Override @@ -1069,11 +1117,11 @@ public DoFn.OnTimerContext onTimerContext(DoFn @Override public State state(String stateId, boolean alwaysFetched) { - StateDeclaration stateDeclaration = context.doFnSignature.stateDeclarations().get(stateId); + StateDeclaration stateDeclaration = doFnSignature.stateDeclarations().get(stateId); checkNotNull(stateDeclaration, "No state declaration found for %s", stateId); StateSpec spec; try { - 
spec = (StateSpec) stateDeclaration.field().get(context.doFn); + spec = (StateSpec) stateDeclaration.field().get(doFn); } catch (IllegalAccessException e) { throw new RuntimeException(e); } @@ -1103,12 +1151,12 @@ public TimerMap timerFamily(String tagId) { @Override public PipelineOptions getPipelineOptions() { - return context.pipelineOptions; + return pipelineOptions; } @Override public PipelineOptions pipelineOptions() { - return context.pipelineOptions; + return pipelineOptions; } @Override @@ -1131,7 +1179,7 @@ public void output(TupleTag tag, T output) { @Override public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { Collection>> consumers = - (Collection) context.localNameToConsumer.get(tag.getId()); + (Collection) localNameToConsumer.get(tag.getId()); if (consumers == null) { throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); } @@ -1166,110 +1214,130 @@ public void updateWatermark(Instant watermark) { } /** Provides arguments for a {@link DoFnInvoker} for {@link DoFn.OnTimer @OnTimer}. 
*/ - private class OnTimerContext extends DoFn.OnTimerContext - implements DoFnInvoker.ArgumentProvider { + private class OnTimerContext extends BaseArgumentProvider { + private class Context extends DoFn.OnTimerContext { + private Context() { + doFn.super(); + } - private OnTimerContext() { - context.doFn.super(); - } + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } - @Override - public BoundedWindow window() { - return currentWindow; - } + @Override + public BoundedWindow window() { + return currentWindow; + } - @Override - public PaneInfo paneInfo(DoFn doFn) { - throw new UnsupportedOperationException( - "Cannot access paneInfo outside of @ProcessElement methods."); - } + @Override + public void output(OutputT output) { + outputTo( + mainOutputConsumers, + WindowedValue.of( + output, currentTimer.getTimestamp(), currentWindow, PaneInfo.NO_FIRING)); + } - @Override - public DoFn.StartBundleContext startBundleContext(DoFn doFn) { - throw new UnsupportedOperationException( - "Cannot access StartBundleContext outside of @StartBundle method."); - } + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + checkArgument( + !currentTimer.getTimestamp().isAfter(timestamp), + "Output time %s can not be before timer timestamp %s.", + timestamp, + currentTimer.getTimestamp()); + outputTo( + mainOutputConsumers, + WindowedValue.of(output, timestamp, currentWindow, PaneInfo.NO_FIRING)); + } - @Override - public DoFn.FinishBundleContext finishBundleContext( - DoFn doFn) { - throw new UnsupportedOperationException( - "Cannot access FinishBundleContext outside of @FinishBundle method."); - } + @Override + public void output(TupleTag tag, T output) { + Collection>> consumers = + (Collection) localNameToConsumer.get(tag.getId()); + if (consumers == null) { + throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); + } + outputTo( + consumers, + WindowedValue.of( + output, 
currentTimer.getTimestamp(), currentWindow, PaneInfo.NO_FIRING)); + } - @Override - public DoFn.ProcessContext processContext(DoFn doFn) { - throw new UnsupportedOperationException( - "Cannot access ProcessContext outside of @ProcessElement method."); - } + @Override + public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + checkArgument( + !currentTimer.getTimestamp().isAfter(timestamp), + "Output time %s can not be before timer timestamp %s.", + timestamp, + currentTimer.getTimestamp()); + Collection>> consumers = + (Collection) localNameToConsumer.get(tag.getId()); + if (consumers == null) { + throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); + } + outputTo(consumers, WindowedValue.of(output, timestamp, currentWindow, PaneInfo.NO_FIRING)); + } - @Override - public InputT element(DoFn doFn) { - throw new UnsupportedOperationException("Element parameters are not supported."); - } + @Override + public TimeDomain timeDomain() { + return currentTimeDomain; + } - @Override - public InputT sideInput(String tagId) { - throw new UnsupportedOperationException("SideInput parameters are not supported."); - } + @Override + public Instant fireTimestamp() { + return currentTimer.getValue().getValue().getTimestamp(); + } - @Override - public Object schemaElement(int index) { - throw new UnsupportedOperationException("Element parameters are not supported."); + @Override + public Instant timestamp() { + return currentTimer.getTimestamp(); + } } + private final Context context = new Context(); + @Override - public Instant timestamp(DoFn doFn) { - return timestamp(); + public BoundedWindow window() { + return currentWindow; } @Override - public String timerId(DoFn doFn) { - throw new UnsupportedOperationException("TimerId parameters are not supported."); + public Instant timestamp(DoFn doFn) { + return currentTimer.getTimestamp(); } @Override public TimeDomain timeDomain(DoFn doFn) { - return timeDomain(); + return 
currentTimeDomain; } @Override public OutputReceiver outputReceiver(DoFn doFn) { - return DoFnOutputReceivers.windowedReceiver(this, null); + return DoFnOutputReceivers.windowedReceiver(context, null); } @Override public OutputReceiver outputRowReceiver(DoFn doFn) { - return DoFnOutputReceivers.rowReceiver(this, null, context.mainOutputSchemaCoder); + return DoFnOutputReceivers.rowReceiver(context, null, mainOutputSchemaCoder); } @Override public MultiOutputReceiver taggedOutputReceiver(DoFn doFn) { - return DoFnOutputReceivers.windowedMultiReceiver(this); - } - - @Override - public Object restriction() { - throw new UnsupportedOperationException("Restriction parameters are not supported."); + return DoFnOutputReceivers.windowedMultiReceiver(context); } @Override public DoFn.OnTimerContext onTimerContext(DoFn doFn) { - return this; - } - - @Override - public RestrictionTracker restrictionTracker() { - throw new UnsupportedOperationException("RestrictionTracker parameters are not supported."); + return context; } @Override public State state(String stateId, boolean alwaysFetched) { - StateDeclaration stateDeclaration = context.doFnSignature.stateDeclarations().get(stateId); + StateDeclaration stateDeclaration = doFnSignature.stateDeclarations().get(stateId); checkNotNull(stateDeclaration, "No state declaration found for %s", stateId); StateSpec spec; try { - spec = (StateSpec) stateDeclaration.field().get(context.doFn); + spec = (StateSpec) stateDeclaration.field().get(doFn); } catch (IllegalAccessException e) { throw new RuntimeException(e); } @@ -1294,78 +1362,17 @@ public org.apache.beam.sdk.state.Timer timer(String timerId) { @Override public TimerMap timerFamily(String tagId) { // TODO: implement timerFamily - throw new UnsupportedOperationException("TimerFamily parameters are not supported."); - } - - @Override - public PipelineOptions getPipelineOptions() { - return context.pipelineOptions; + return super.timerFamily(tagId); } @Override public PipelineOptions 
pipelineOptions() { - return context.pipelineOptions; - } - - @Override - public void output(OutputT output) { - outputTo( - mainOutputConsumers, - WindowedValue.of(output, currentTimer.getTimestamp(), currentWindow, PaneInfo.NO_FIRING)); - } - - @Override - public void outputWithTimestamp(OutputT output, Instant timestamp) { - checkArgument( - !currentTimer.getTimestamp().isAfter(timestamp), - "Output time %s can not be before timer timestamp %s.", - timestamp, - currentTimer.getTimestamp()); - outputTo( - mainOutputConsumers, - WindowedValue.of(output, timestamp, currentWindow, PaneInfo.NO_FIRING)); - } - - @Override - public void output(TupleTag tag, T output) { - Collection>> consumers = - (Collection) context.localNameToConsumer.get(tag.getId()); - if (consumers == null) { - throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); - } - outputTo( - consumers, - WindowedValue.of(output, currentTimer.getTimestamp(), currentWindow, PaneInfo.NO_FIRING)); - } - - @Override - public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { - checkArgument( - !currentTimer.getTimestamp().isAfter(timestamp), - "Output time %s can not be before timer timestamp %s.", - timestamp, - currentTimer.getTimestamp()); - Collection>> consumers = - (Collection) context.localNameToConsumer.get(tag.getId()); - if (consumers == null) { - throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); - } - outputTo(consumers, WindowedValue.of(output, timestamp, currentWindow, PaneInfo.NO_FIRING)); + return pipelineOptions; } @Override - public TimeDomain timeDomain() { - return currentTimeDomain; - } - - @Override - public Instant fireTimestamp() { - return currentTimer.getValue().getValue().getTimestamp(); - } - - @Override - public Instant timestamp() { - return currentTimer.getTimestamp(); + public String getErrorContext() { + return "FnApiDoFnRunner/OnTimer"; } } } diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java index 579d44707de16..ac79161edda70 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java @@ -23,6 +23,7 @@ import java.util.function.Function; import org.apache.beam.fn.harness.control.AddHarnessIdInterceptor; import org.apache.beam.fn.harness.control.BeamFnControlClient; +import org.apache.beam.fn.harness.control.FinalizeBundleHandler; import org.apache.beam.fn.harness.control.ProcessBundleHandler; import org.apache.beam.fn.harness.control.RegisterHandler; import org.apache.beam.fn.harness.data.BeamFnDataGrpcClient; @@ -190,10 +191,20 @@ public static void main( new BeamFnStateGrpcClientCache( idGenerator, channelFactory::forDescriptor, outboundObserverFactory); + FinalizeBundleHandler finalizeBundleHandler = + new FinalizeBundleHandler(options.as(GcsOptions.class).getExecutorService()); + ProcessBundleHandler processBundleHandler = new ProcessBundleHandler( - options, fnApiRegistry::getById, beamFnDataMultiplexer, beamFnStateGrpcClientCache); + options, + fnApiRegistry::getById, + beamFnDataMultiplexer, + beamFnStateGrpcClientCache, + finalizeBundleHandler); handlers.put(BeamFnApi.InstructionRequest.RequestCase.REGISTER, fnApiRegistry::register); + handlers.put( + BeamFnApi.InstructionRequest.RequestCase.FINALIZE_BUNDLE, + finalizeBundleHandler::finalizeBundle); // TODO(BEAM-6597): Collect MonitoringInfos in ProcessBundleProgressResponses. 
handlers.put( BeamFnApi.InstructionRequest.RequestCase.PROCESS_BUNDLE, diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java index 0cd5dcae08542..d680b471f46a9 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java @@ -35,6 +35,7 @@ import org.apache.beam.sdk.function.ThrowingFunction; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; @@ -110,7 +111,8 @@ public Mapper createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { FnDataReceiver> consumer = diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java index 389331f87c218..c21e9e2d22579 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java @@ -32,6 +32,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; /** A factory able to instantiate an appropriate handler for a given PTransform. 
*/ public interface PTransformRunnerFactory { @@ -59,6 +60,9 @@ public interface PTransformRunnerFactory { * @param finishFunctionRegistry A class to register a finish bundle handler with. * @param addTearDownFunction A consumer to register a tear down handler with. * @param splitListener A listener to be invoked when the PTransform splits itself. + * @param bundleFinalizer Register callbacks that will be invoked when the runner completes the + * bundle. The specified instant provides the timeout on how long the finalization callback is + * valid for. */ T createRunnerForPTransform( PipelineOptions pipelineOptions, @@ -74,7 +78,8 @@ T createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer addTearDownFunction, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException; /** diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java new file mode 100644 index 0000000000000..985a9cd00baca --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.control; + +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; + +import com.google.auto.value.AutoValue; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.PriorityQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleResponse; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A bundle finalization handler that expires entries after a specified amount of time. + * + *

Callers should register new callbacks via {@link #registerCallbacks} and fire existing + * callbacks using {@link #finalizeBundle}. + * + *

See Apache Beam Portability API: How to + * Finalize Bundles for further details. + */ +public class FinalizeBundleHandler { + + /** A {@link BundleFinalizer.Callback} and expiry time pair. */ + @AutoValue + abstract static class CallbackRegistration { + public static CallbackRegistration create( + Instant expiryTime, BundleFinalizer.Callback callback) { + return new AutoValue_FinalizeBundleHandler_CallbackRegistration(expiryTime, callback); + } + + public abstract Instant getExpiryTime(); + + public abstract BundleFinalizer.Callback getCallback(); + } + + private static final Logger LOGGER = LoggerFactory.getLogger(FinalizeBundleHandler.class); + private final ConcurrentMap> bundleFinalizationCallbacks; + private final PriorityQueue> cleanUpQueue; + private final Future cleanUpResult; + + public FinalizeBundleHandler(ExecutorService executorService) { + this.bundleFinalizationCallbacks = new ConcurrentHashMap<>(); + this.cleanUpQueue = + new PriorityQueue<>(11, Comparator.comparing(TimestampedValue::getTimestamp)); + this.cleanUpResult = + executorService.submit( + (Callable) + () -> { + while (true) { + synchronized (cleanUpQueue) { + TimestampedValue expiryTime = cleanUpQueue.peek(); + + // Wait until we have at least one element. We are notified on each element + // being added. + while (expiryTime == null) { + cleanUpQueue.wait(); + expiryTime = cleanUpQueue.peek(); + } + + // Wait until the current time has passed the expiry time for the head of the + // queue. + // We are notified on each element being added. 
+ Instant now = Instant.now(); + while (expiryTime.getTimestamp().isAfter(now)) { + Duration timeDifference = new Duration(now, expiryTime.getTimestamp()); + cleanUpQueue.wait(timeDifference.getMillis()); + expiryTime = cleanUpQueue.peek(); + now = Instant.now(); + } + + bundleFinalizationCallbacks.remove(cleanUpQueue.poll().getValue()); + } + } + }); + } + + public void registerCallbacks(String bundleId, Collection callbacks) { + if (callbacks.isEmpty()) { + return; + } + + Collection priorCallbacks = + bundleFinalizationCallbacks.putIfAbsent(bundleId, callbacks); + checkState( + priorCallbacks == null, + "Expected to not have any past callbacks for bundle %s but found %s.", + bundleId, + priorCallbacks); + long expiryTimeMillis = Long.MIN_VALUE; + for (CallbackRegistration callback : callbacks) { + expiryTimeMillis = Math.max(expiryTimeMillis, callback.getExpiryTime().getMillis()); + } + synchronized (cleanUpQueue) { + cleanUpQueue.offer(TimestampedValue.of(bundleId, new Instant(expiryTimeMillis))); + cleanUpQueue.notify(); + } + } + + public BeamFnApi.InstructionResponse.Builder finalizeBundle(BeamFnApi.InstructionRequest request) + throws Exception { + String bundleId = request.getFinalizeBundle().getInstructionId(); + + Collection callbacks = bundleFinalizationCallbacks.remove(bundleId); + + if (callbacks == null) { + // We have already processed the callbacks on a prior bundle finalization attempt + return BeamFnApi.InstructionResponse.newBuilder() + .setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance()); + } + + Collection failures = new ArrayList<>(); + for (CallbackRegistration callback : callbacks) { + try { + callback.getCallback().onBundleSuccess(); + } catch (Exception e) { + failures.add(e); + } + } + if (!failures.isEmpty()) { + Exception e = + new Exception( + String.format("Failed to handle bundle finalization for bundle %s.", bundleId)); + for (Exception failure : failures) { + e.addSuppressed(failure); + } + throw e; + } + + return 
BeamFnApi.InstructionResponse.newBuilder() + .setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance()); + } +} diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index 00882b5607812..48e24509bf96d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -21,6 +21,7 @@ import java.io.Closeable; import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.List; @@ -36,6 +37,7 @@ import java.util.function.Supplier; import org.apache.beam.fn.harness.PTransformRunnerFactory; import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar; +import org.apache.beam.fn.harness.control.FinalizeBundleHandler.CallbackRegistration; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry; import org.apache.beam.fn.harness.data.PTransformFunctionRegistry; @@ -63,18 +65,21 @@ import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Message; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.TextFormat; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashMultimap; +import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.SetMultimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; +import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -121,6 +126,7 @@ public class ProcessBundleHandler { private final Function fnApiRegistry; private final BeamFnDataClient beamFnDataClient; private final BeamFnStateGrpcClientCache beamFnStateGrpcClientCache; + private final FinalizeBundleHandler finalizeBundleHandler; private final Map urnToPTransformRunnerFactoryMap; private final PTransformRunnerFactory defaultPTransformRunnerFactory; @VisibleForTesting final BundleProcessorCache bundleProcessorCache; @@ -129,12 +135,14 @@ public ProcessBundleHandler( PipelineOptions options, Function fnApiRegistry, BeamFnDataClient beamFnDataClient, - BeamFnStateGrpcClientCache beamFnStateGrpcClientCache) { + BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, + FinalizeBundleHandler finalizeBundleHandler) { this( options, fnApiRegistry, beamFnDataClient, beamFnStateGrpcClientCache, + finalizeBundleHandler, REGISTERED_RUNNER_FACTORIES, new BundleProcessorCache()); } @@ -145,12 +153,14 @@ public ProcessBundleHandler( Function fnApiRegistry, BeamFnDataClient beamFnDataClient, BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, + FinalizeBundleHandler finalizeBundleHandler, Map urnToPTransformRunnerFactoryMap, BundleProcessorCache bundleProcessorCache) { this.options = options; this.fnApiRegistry = fnApiRegistry; this.beamFnDataClient = beamFnDataClient; 
this.beamFnStateGrpcClientCache = beamFnStateGrpcClientCache; + this.finalizeBundleHandler = finalizeBundleHandler; this.urnToPTransformRunnerFactoryMap = urnToPTransformRunnerFactoryMap; this.defaultPTransformRunnerFactory = new UnknownPTransformRunnerFactory(urnToPTransformRunnerFactoryMap.keySet()); @@ -170,7 +180,8 @@ private void createRunnerAndConsumersForPTransformRecursively( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer addTearDownFunction, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { // Recursively ensure that all consumers of the output PCollection have been created. @@ -192,7 +203,8 @@ private void createRunnerAndConsumersForPTransformRecursively( startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener); + splitListener, + bundleFinalizer); } } @@ -225,7 +237,8 @@ private void createRunnerAndConsumersForPTransformRecursively( startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener); + splitListener, + bundleFinalizer); processedPTransformIds.add(pTransformId); } } @@ -298,6 +311,14 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction for (MonitoringInfo mi : metricsContainerRegistry.getMonitoringInfos()) { response.addMonitoringInfos(mi); } + + if (!bundleProcessor.getBundleFinalizationCallbackRegistrations().isEmpty()) { + finalizeBundleHandler.registerCallbacks( + bundleProcessor.getInstructionId(), + ImmutableList.copyOf(bundleProcessor.getBundleFinalizationCallbackRegistrations())); + response.setRequiresFinalization(true); + } + bundleProcessorCache.release( request.getProcessBundle().getProcessBundleDescriptorId(), bundleProcessor); } @@ -382,6 +403,16 @@ private BundleProcessor createBundleProcessor( } }; + Collection bundleFinalizationCallbackRegistrations = new ArrayList<>(); + BundleFinalizer bundleFinalizer = + 
new BundleFinalizer() { + @Override + public void afterBundleCommit(Instant callbackExpiry, Callback callback) { + bundleFinalizationCallbackRegistrations.add( + CallbackRegistration.create(callbackExpiry, callback)); + } + }; + BundleProcessor bundleProcessor = BundleProcessor.create( startFunctionRegistry, @@ -392,7 +423,8 @@ private BundleProcessor createBundleProcessor( metricsContainerRegistry, stateTracker, beamFnStateClient, - queueingClient); + queueingClient, + bundleFinalizationCallbackRegistrations); // Create a BeamFnStateClient for (Map.Entry entry : @@ -420,7 +452,8 @@ private BundleProcessor createBundleProcessor( startFunctionRegistry, finishFunctionRegistry, tearDownFunctions::add, - splitListener); + splitListener, + bundleFinalizer); } return bundleProcessor; } @@ -517,7 +550,8 @@ public static BundleProcessor create( MetricsContainerStepMap metricsContainerRegistry, ExecutionStateTracker stateTracker, HandleStateCallsForBundle beamFnStateClient, - QueueingBeamFnDataClient queueingClient) { + QueueingBeamFnDataClient queueingClient, + Collection bundleFinalizationCallbackRegistrations) { return new AutoValue_ProcessBundleHandler_BundleProcessor( startFunctionRegistry, finishFunctionRegistry, @@ -527,7 +561,8 @@ public static BundleProcessor create( metricsContainerRegistry, stateTracker, beamFnStateClient, - queueingClient); + queueingClient, + bundleFinalizationCallbackRegistrations); } private String instructionId; @@ -550,6 +585,8 @@ public static BundleProcessor create( abstract QueueingBeamFnDataClient getQueueingClient(); + abstract Collection getBundleFinalizationCallbackRegistrations(); + String getInstructionId() { return this.instructionId; } @@ -566,6 +603,7 @@ void reset() { getMetricsContainerRegistry().reset(); getStateTracker().reset(); ExecutionStateSampler.instance().reset(); + getBundleFinalizationCallbackRegistrations().clear(); } } @@ -659,7 +697,8 @@ public Object createRunnerForPTransform( PTransformFunctionRegistry 
startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) { + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) { String message = String.format( "No factory registered for %s, known factories %s", diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java index c27400f5f05fd..f00c4f941e8ae 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java @@ -206,7 +206,8 @@ public Coder windowCoder() { null /* startFunctionRegistry */, null, /* finishFunctionRegistry */ null, /* tearDownRegistry */ - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); WindowedValue value = WindowedValue.of( diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java index dc319d7312628..8b9daab05656d 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java @@ -175,7 +175,8 @@ public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); assertThat(teardownFunctions, empty()); @@ -476,6 +477,7 @@ private BeamFnDataReadRunner createReadRunner( startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); } } diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java index dade2d9094899..5085fe9320480 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java @@ -149,7 +149,8 @@ public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); assertThat(teardownFunctions, empty()); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java index 0f461d20c4650..e419783da7ae3 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java @@ -173,7 +173,8 @@ public void testCreatingAndProcessingSourceFromFactory() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); // This is testing a deprecated way of running sources and should be removed // once all source definitions are instead propagated along the input edge. 
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java index 7f5be398a148d..0b0002f72b869 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java @@ -157,6 +157,7 @@ public void testPrecombine() throws Exception { startFunctionRegistry, finishFunctionRegistry, null, + null, null); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); @@ -232,6 +233,7 @@ public void testMergeAccumulators() throws Exception { startFunctionRegistry, finishFunctionRegistry, null, + null, null); assertThat(startFunctionRegistry.getFunctions(), empty()); @@ -295,6 +297,7 @@ public void testExtractOutputs() throws Exception { startFunctionRegistry, finishFunctionRegistry, null, + null, null); assertThat(startFunctionRegistry.getFunctions(), empty()); @@ -358,6 +361,7 @@ public void testCombineGroupedValues() throws Exception { startFunctionRegistry, finishFunctionRegistry, null, + null, null); assertThat(startFunctionRegistry.getFunctions(), empty()); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java index dc78ea2b22c84..d9f55f76015cc 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java @@ -92,7 +92,8 @@ public void testCreatingAndProcessingDoFlatten() throws Exception { null /* startFunctionRegistry */, null, /* finishFunctionRegistry */ null, /* tearDownRegistry */ - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); mainOutputValues.clear(); assertThat( @@ -160,7 +161,8 @@ public void 
testFlattenWithDuplicateInputCollectionProducesMultipleOutputs() thr null /* startFunctionRegistry */, null, /* finishFunctionRegistry */ null, /* tearDownRegistry */ - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); mainOutputValues.clear(); assertThat(consumers.keySet(), containsInAnyOrder("inputATarget", "mainOutputTarget")); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java index cf75bee2302e1..d5e8b6df6b0e8 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java @@ -239,7 +239,8 @@ public void testUsingUserState() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -414,7 +415,8 @@ public void testBasicWithSideInputsAndOutputs() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -552,7 +554,8 @@ public void testSideInputIsAccessibleForDownstreamCallers() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -667,7 +670,8 @@ public void testUsingMetrics() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + 
null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -878,7 +882,8 @@ public void testTimers() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -1153,7 +1158,8 @@ public void split( primarySplits.addAll(primaryRoots); residualSplits.addAll(residualRoots); } - }); + }, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -1259,7 +1265,8 @@ public void testProcessElementForPairWithRestriction() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* bundleSplitListener */); + null /* bundleSplitListener */, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -1348,7 +1355,8 @@ public void testProcessElementForSplitAndSizeRestriction() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* bundleSplitListener */); + null /* bundleSplitListener */, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java index 956e295dd31e7..7d9abbbde6885 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java @@ -93,7 +93,8 @@ public void testValueOnlyMapping() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + 
null /* splitListener */, + null /* bundleFinalizer */); assertThat(startFunctionRegistry.getFunctions(), empty()); assertThat(finishFunctionRegistry.getFunctions(), empty()); @@ -138,7 +139,8 @@ public void testFullWindowedValueMapping() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); assertThat(startFunctionRegistry.getFunctions(), empty()); assertThat(finishFunctionRegistry.getFunctions(), empty()); @@ -182,7 +184,8 @@ public void testFullWindowedValueMappingWithCompressedWindow() throws Exception startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); assertThat(startFunctionRegistry.getFunctions(), empty()); assertThat(finishFunctionRegistry.getFunctions(), empty()); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java new file mode 100644 index 0000000000000..a760d22b78af0 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.control; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.StringContains.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.beam.fn.harness.control.FinalizeBundleHandler.CallbackRegistration; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleRequest; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleResponse; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionRequest; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link FinalizeBundleHandler}. 
*/ +@RunWith(JUnit4.class) +public class FinalizeBundleHandlerTest { + private static final String INSTRUCTION_ID = "instructionId"; + private static final InstructionResponse SUCCESSFUL_RESPONSE = + InstructionResponse.newBuilder() + .setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance()) + .build(); + + @Test + public void testRegistrationAndCallback() throws Exception { + AtomicBoolean wasCalled1 = new AtomicBoolean(); + AtomicBoolean wasCalled2 = new AtomicBoolean(); + List callbacks = new ArrayList<>(); + callbacks.add( + CallbackRegistration.create( + Instant.now().plus(Duration.standardHours(1)), () -> wasCalled1.set(true))); + callbacks.add( + CallbackRegistration.create( + Instant.now().plus(Duration.standardHours(1)), () -> wasCalled2.set(true))); + + FinalizeBundleHandler handler = new FinalizeBundleHandler(Executors.newCachedThreadPool()); + handler.registerCallbacks("test", callbacks); + assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test")).build()); + assertTrue(wasCalled1.get()); + assertTrue(wasCalled2.get()); + } + + @Test + public void testFinalizationIgnoresMissingBundleIds() throws Exception { + FinalizeBundleHandler handler = new FinalizeBundleHandler(Executors.newCachedThreadPool()); + assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test")).build()); + } + + @Test + public void testFinalizationContinuesToNextCallbackEvenInFailure() throws Exception { + List callbacks = new ArrayList<>(); + AtomicBoolean wasCalled1 = new AtomicBoolean(); + AtomicBoolean wasCalled2 = new AtomicBoolean(); + callbacks.add( + CallbackRegistration.create( + Instant.now().plus(Duration.standardHours(1)), + () -> { + wasCalled1.set(true); + throw new Exception("testException1"); + })); + callbacks.add( + CallbackRegistration.create( + Instant.now().plus(Duration.standardHours(1)), + () -> { + wasCalled2.set(true); + throw new Exception("testException2"); + })); + + FinalizeBundleHandler handler = new 
FinalizeBundleHandler(Executors.newCachedThreadPool()); + handler.registerCallbacks("test", callbacks); + + try { + handler.finalizeBundle(requestFor("test")); + fail(); + } catch (Exception e) { + assertThat(e.getMessage(), containsString("Failed to handle bundle finalization for bundle")); + assertEquals(2, e.getSuppressed().length); + assertTrue(wasCalled1.get()); + assertTrue(wasCalled2.get()); + } + } + + private static InstructionRequest requestFor(String bundleId) { + return InstructionRequest.newBuilder() + .setInstructionId(INSTRUCTION_ID) + .setFinalizeBundle(FinalizeBundleRequest.newBuilder().setInstructionId(bundleId).build()) + .build(); + } +} diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index f6e65537b6e11..f9aeb6b485f3e 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -21,12 +21,15 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; +import static org.mockito.Mockito.argThat; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -34,6 +37,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import 
java.util.List; import java.util.Map; @@ -42,6 +46,7 @@ import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.PTransformRunnerFactory; +import org.apache.beam.fn.harness.control.FinalizeBundleHandler.CallbackRegistration; import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor; import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessorCache; import org.apache.beam.fn.harness.data.BeamFnDataClient; @@ -71,6 +76,7 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.DoFnWithExecutionInformation; @@ -80,9 +86,11 @@ import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Message; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles; +import org.joda.time.Instant; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -213,6 +221,11 @@ QueueingBeamFnDataClient getQueueingClient() { return wrappedBundleProcessor.getQueueingClient(); } + @Override + Collection<CallbackRegistration> getBundleFinalizationCallbackRegistrations() { + return wrappedBundleProcessor.getBundleFinalizationCallbackRegistrations(); + } + @Override void reset() { resetCnt++; @@ -269,7 +282,8 @@ public void testOrderOfStartAndFinishCalls() throws
Exception { startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener) -> { + splitListener, + bundleFinalizer) -> { transformsProcessed.add(pTransform); startFunctionRegistry.register( pTransformId, @@ -292,6 +306,7 @@ public void testOrderOfStartAndFinishCalls() throws Exception { fnApiRegistry::get, beamFnDataClient, null /* beamFnStateClient */, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, startFinishRecorder, DATA_OUTPUT_URN, startFinishRecorder), @@ -398,7 +413,8 @@ public void testOrderOfSetupTeardownCalls() throws Exception { startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener) -> null); + splitListener, + bundleFinalizer) -> null); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -406,6 +422,7 @@ public void testOrderOfSetupTeardownCalls() throws Exception { fnApiRegistry::get, beamFnDataClient, null /* beamFnStateClient */, + null /* finalizeBundleHandler */, urnToPTransformRunnerFactoryMap, new BundleProcessorCache()); @@ -451,6 +468,7 @@ public void testBundleProcessorIsResetWhenAddedBackToCache() throws Exception { fnApiRegistry::get, beamFnDataClient, null /* beamFnStateGrpcClientCache */, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, (pipelineOptions, @@ -466,7 +484,8 @@ public void testBundleProcessorIsResetWhenAddedBackToCache() throws Exception { startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener) -> null), + splitListener, + bundleFinalizer) -> null), new TestBundleProcessorCache()); assertThat(TestBundleProcessor.resetCnt, equalTo(0)); @@ -510,6 +529,7 @@ public void testBundleProcessorReset() { PTransformFunctionRegistry startFunctionRegistry = mock(PTransformFunctionRegistry.class); PTransformFunctionRegistry finishFunctionRegistry = mock(PTransformFunctionRegistry.class); Multimap allResiduals = mock(Multimap.class); + Collection bundleFinalizationCallbacks = mock(Collection.class); 
PCollectionConsumerRegistry pCollectionConsumerRegistry = mock(PCollectionConsumerRegistry.class); MetricsContainerStepMap metricsContainerRegistry = mock(MetricsContainerStepMap.class); @@ -527,7 +547,8 @@ public void testBundleProcessorReset() { metricsContainerRegistry, stateTracker, beamFnStateClient, - queueingClient); + queueingClient, + bundleFinalizationCallbacks); bundleProcessor.reset(); verify(startFunctionRegistry, times(1)).reset(); @@ -536,6 +557,7 @@ public void testBundleProcessorReset() { verify(pCollectionConsumerRegistry, times(1)).reset(); verify(metricsContainerRegistry, times(1)).reset(); verify(stateTracker, times(1)).reset(); + verify(bundleFinalizationCallbacks, times(1)).clear(); } @Test @@ -556,6 +578,7 @@ public void testCreatingPTransformExceptionsArePropagated() throws Exception { fnApiRegistry::get, beamFnDataClient, null /* beamFnStateGrpcClientCache */, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, (pipelineOptions, @@ -571,7 +594,8 @@ public void testCreatingPTransformExceptionsArePropagated() throws Exception { startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener) -> { + splitListener, + bundleFinalizer) -> { thrown.expect(IllegalStateException.class); thrown.expectMessage("TestException"); throw new IllegalStateException("TestException"); @@ -584,6 +608,74 @@ public void testCreatingPTransformExceptionsArePropagated() throws Exception { .build()); } + @Test + public void testBundleFinalizationIsPropagated() throws Exception { + BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = + BeamFnApi.ProcessBundleDescriptor.newBuilder() + .putTransforms( + "2L", + RunnerApi.PTransform.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()) + .build()) + .build(); + Map<String, Message> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor); + FinalizeBundleHandler mockFinalizeBundleHandler = mock(FinalizeBundleHandler.class); + BundleFinalizer.Callback
mockCallback = mock(BundleFinalizer.Callback.class); + + ProcessBundleHandler handler = + new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient, + null /* beamFnStateGrpcClientCache */, + mockFinalizeBundleHandler, + ImmutableMap.of( + DATA_INPUT_URN, + (PTransformRunnerFactory) + (pipelineOptions, + beamFnDataClient, + beamFnStateClient, + pTransformId, + pTransform, + processBundleInstructionId, + pCollections, + coders, + windowingStrategies, + pCollectionConsumerRegistry, + startFunctionRegistry, + finishFunctionRegistry, + addTearDownFunction, + splitListener, + bundleFinalizer) -> { + startFunctionRegistry.register( + pTransformId, + () -> + bundleFinalizer.afterBundleCommit( + Instant.ofEpochMilli(42L), mockCallback)); + return null; + }), + new BundleProcessorCache()); + BeamFnApi.InstructionResponse.Builder response = + handler.processBundle( + BeamFnApi.InstructionRequest.newBuilder() + .setInstructionId("2L") + .setProcessBundle( + BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")) + .build()); + + assertTrue(response.getProcessBundle().getRequiresFinalization()); + verify(mockFinalizeBundleHandler) + .registerCallbacks( + eq("2L"), + argThat( + (Collection<CallbackRegistration> arg) -> { + CallbackRegistration registration = Iterables.getOnlyElement(arg); + assertEquals(Instant.ofEpochMilli(42L), registration.getExpiryTime()); + assertSame(mockCallback, registration.getCallback()); + return true; + })); + } + @Test public void testPTransformStartExceptionsArePropagated() throws Exception { BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = @@ -602,6 +694,7 @@ public void testPTransformStartExceptionsArePropagated() throws Exception { fnApiRegistry::get, beamFnDataClient, null /* beamFnStateGrpcClientCache */, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, (PTransformRunnerFactory) @@ -618,7 +711,8 @@ public void testPTransformStartExceptionsArePropagated() throws
Exception { startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener) -> { + splitListener, + bundleFinalizer) -> { thrown.expect(IllegalStateException.class); thrown.expectMessage("TestException"); startFunctionRegistry.register( @@ -656,6 +750,7 @@ public void testPTransformFinishExceptionsArePropagated() throws Exception { fnApiRegistry::get, beamFnDataClient, null /* beamFnStateGrpcClientCache */, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, (PTransformRunnerFactory) @@ -672,7 +767,8 @@ public void testPTransformFinishExceptionsArePropagated() throws Exception { startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener) -> { + splitListener, + bundleFinalizer) -> { thrown.expect(IllegalStateException.class); thrown.expectMessage("TestException"); finishFunctionRegistry.register( @@ -746,6 +842,7 @@ public void testPendingStateCallsBlockTillCompletion() throws Exception { fnApiRegistry::get, beamFnDataClient, mockBeamFnStateGrpcClient, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, new PTransformRunnerFactory() { @@ -764,7 +861,8 @@ public Object createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer addTearDownFunction, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { startFunctionRegistry.register( pTransformId, () -> doStateCalls(beamFnStateClient)); @@ -807,6 +905,7 @@ public void testStateCallsFailIfNoStateApiServiceDescriptorSpecified() throws Ex fnApiRegistry::get, beamFnDataClient, null /* beamFnStateGrpcClientCache */, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, new PTransformRunnerFactory() { @@ -825,7 +924,8 @@ public Object createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer 
addTearDownFunction, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { startFunctionRegistry.register( pTransformId, () -> doStateCalls(beamFnStateClient));