diff --git a/runners/apex/build.gradle b/runners/apex/build.gradle index cb4d62126cd40..717996a4e0523 100644 --- a/runners/apex/build.gradle +++ b/runners/apex/build.gradle @@ -100,6 +100,7 @@ task validatesRunnerBatch(type: Test) { excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedPCollections' // TODO[BEAM-8304]: Support multiple side inputs with different coders. excludeCategories 'org.apache.beam.sdk.testing.UsesSideInputsWithDifferentCoders' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } // apex runner is run in embedded mode. Increase default HeapSize diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java index 19a272a9264ab..da8dae0280bf2 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java @@ -288,6 +288,25 @@ public boolean isRequiresTimeSortedInput() { return signature.processElement().requiresTimeSortedInput(); } + @Override + public boolean requestsFinalization() { + return (signature.startBundle() != null + && signature + .startBundle() + .extraParameters() + .contains(Parameter.bundleFinalizer())) + || (signature.processElement() != null + && signature + .processElement() + .extraParameters() + .contains(Parameter.bundleFinalizer())) + || (signature.finishBundle() != null + && signature + .finishBundle() + .extraParameters() + .contains(Parameter.bundleFinalizer())); + } + @Override public String translateRestrictionCoderId(SdkComponents newComponents) { return restrictionCoderId; @@ -763,6 +782,8 @@ Map translateStateSpecs(SdkComponents components) boolean isRequiresTimeSortedInput(); + boolean requestsFinalization(); + String 
translateRestrictionCoderId(SdkComponents newComponents); } @@ -779,6 +800,7 @@ public static ParDoPayload payloadForParDoLike(ParDoLike parDo, SdkComponents co .setSplittable(parDo.isSplittable()) .setRequiresTimeSortedInput(parDo.isRequiresTimeSortedInput()) .setRestrictionCoderId(parDo.translateRestrictionCoderId(components)) + .setRequestsFinalization(parDo.requestsFinalization()) .build(); } } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java index adb018e48ff66..84ed10a3fe932 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java @@ -416,6 +416,25 @@ public boolean isRequiresTimeSortedInput() { return false; } + @Override + public boolean requestsFinalization() { + return (signature.startBundle() != null + && signature + .startBundle() + .extraParameters() + .contains(DoFnSignature.Parameter.bundleFinalizer())) + || (signature.processElement() != null + && signature + .processElement() + .extraParameters() + .contains(DoFnSignature.Parameter.bundleFinalizer())) + || (signature.finishBundle() != null + && signature + .finishBundle() + .extraParameters() + .contains(DoFnSignature.Parameter.bundleFinalizer())); + } + @Override public String translateRestrictionCoderId(SdkComponents newComponents) { return restrictionCoderId; diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java index cf40ddedd400d..30a44da55e8f6 100644 --- 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java @@ -31,6 +31,7 @@ import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.state.TimerMap; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.StartBundleContext; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Reshuffle; @@ -132,10 +133,21 @@ public void setup() { @StartBundle public void startBundle(StartBundleContext c) { invoker.invokeStartBundle( - new DoFn.StartBundleContext() { + new BaseArgumentProvider() { @Override - public PipelineOptions getPipelineOptions() { - return c.getPipelineOptions(); + public DoFn.StartBundleContext startBundleContext( + DoFn doFn) { + return new DoFn.StartBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return c.getPipelineOptions(); + } + }; + } + + @Override + public String getErrorContext() { + return "SplittableParDoNaiveBounded/StartBundle"; } }); } @@ -174,23 +186,35 @@ public String getErrorContext() { @FinishBundle public void finishBundle(FinishBundleContext c) { invoker.invokeFinishBundle( - new DoFn.FinishBundleContext() { - @Override - public PipelineOptions getPipelineOptions() { - return c.getPipelineOptions(); - } - + new BaseArgumentProvider() { @Override - public void output(@Nullable OutputT output, Instant timestamp, BoundedWindow window) { - throw new UnsupportedOperationException( - "Output from FinishBundle for SDF is not supported"); + public DoFn.FinishBundleContext finishBundleContext( + DoFn doFn) { + return new DoFn.FinishBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return c.getPipelineOptions(); + } + + @Override + public void output( + @Nullable OutputT output, 
Instant timestamp, BoundedWindow window) { + throw new UnsupportedOperationException( + "Output from FinishBundle for SDF is not supported"); + } + + @Override + public void output( + TupleTag tag, T output, Instant timestamp, BoundedWindow window) { + throw new UnsupportedOperationException( + "Output from FinishBundle for SDF is not supported"); + } + }; } @Override - public void output( - TupleTag tag, T output, Instant timestamp, BoundedWindow window) { - throw new UnsupportedOperationException( - "Output from FinishBundle for SDF is not supported"); + public String getErrorContext() { + return "SplittableParDoNaiveBounded/FinishBundle"; } }); } @@ -317,6 +341,11 @@ public OutputReceiver getRowReceiver(TupleTag tag) { }; } + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException(); + } + @Override public Object restriction() { return tracker.currentRestriction(); diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java index 5e73d37f3db34..2eccc615dacf9 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java @@ -20,7 +20,9 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; import java.util.HashMap; import java.util.Map; @@ -47,6 +49,8 @@ import org.apache.beam.sdk.transforms.Combine.BinaryCombineLongFn; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; +import 
org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; +import org.apache.beam.sdk.transforms.DoFn.ProcessElement; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.ParDo.MultiOutput; @@ -64,12 +68,15 @@ import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.junit.Test; +import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; /** Tests for {@link ParDoTranslation}. */ +@RunWith(Enclosed.class) public class ParDoTranslationTest { /** Tests for translating various {@link ParDo} transforms to/from {@link ParDoPayload} protos. */ @@ -125,6 +132,7 @@ public void testToProto() throws Exception { for (PCollectionView view : parDo.getSideInputs().values()) { payload.getSideInputsOrThrow(view.getTagInternal().getId()); } + assertFalse(payload.getRequestsFinalization()); } @Test @@ -338,4 +346,73 @@ public int hashCode() { return StateTimerDropElementsFn.class.hashCode(); } } + + @RunWith(JUnit4.class) + public static class BundleFinalizerTranslation { + private static class StartBundleDoFn extends DoFn { + @StartBundle + public void startBundle(BundleFinalizer bundleFinalizer) {} + + @ProcessElement + public void processElement() {} + } + + private static class ProcessContextDoFn extends DoFn { + @ProcessElement + public void processElement(BundleFinalizer finalizer) {} + } + + private static class FinishBundleDoFn extends DoFn { + @FinishBundle + public void finishBundle(BundleFinalizer bundleFinalizer) {} + + @ProcessElement + public void processElement() {} + } + + @Test + public void testStartBundle() throws Exception { + SdkComponents sdkComponents = 
SdkComponents.create(); + sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java")); + ParDoPayload payload = + ParDoTranslation.translateParDo( + ParDo.of(new StartBundleDoFn()) + .withOutputTags(new TupleTag<>(), TupleTagList.empty()), + DoFnSchemaInformation.create(), + TestPipeline.create(), + sdkComponents); + + assertTrue(payload.getRequestsFinalization()); + } + + @Test + public void testProcessContext() throws Exception { + SdkComponents sdkComponents = SdkComponents.create(); + sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java")); + ParDoPayload payload = + ParDoTranslation.translateParDo( + ParDo.of(new ProcessContextDoFn()) + .withOutputTags(new TupleTag<>(), TupleTagList.empty()), + DoFnSchemaInformation.create(), + TestPipeline.create(), + sdkComponents); + + assertTrue(payload.getRequestsFinalization()); + } + + @Test + public void testFinishBundle() throws Exception { + SdkComponents sdkComponents = SdkComponents.create(); + sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java")); + ParDoPayload payload = + ParDoTranslation.translateParDo( + ParDo.of(new FinishBundleDoFn()) + .withOutputTags(new TupleTag<>(), TupleTagList.empty()), + DoFnSchemaInformation.create(), + TestPipeline.create(), + sdkComponents); + + assertTrue(payload.getRequestsFinalization()); + } + } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java index 21bdf31d991de..010eb11a9425f 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java @@ -32,6 +32,7 @@ import org.apache.beam.sdk.state.Timer; import 
org.apache.beam.sdk.state.TimerMap; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext; import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; @@ -169,6 +170,12 @@ public MultiOutputReceiver taggedOutputReceiver(DoFn doFn) { return DoFnOutputReceivers.windowedMultiReceiver(processContext, null); } + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + "Not supported in non-portable SplittableDoFn"); + } + @Override public RestrictionTracker restrictionTracker() { return processContext.tracker; diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java index e4362ebed2072..71efa128209c6 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java @@ -38,6 +38,7 @@ import org.apache.beam.sdk.state.TimerMap; import org.apache.beam.sdk.state.TimerSpec; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; import org.apache.beam.sdk.transforms.DoFnOutputReceivers; @@ -391,6 +392,12 @@ public TimerMap timerFamily(String tagId) { throw new UnsupportedOperationException( "Cannot access timer family outside of @ProcessElement and @OnTimer methods"); } + + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + "Bundle finalization is not supported in non-portable pipelines."); + } } /** B A concrete implementation of {@link DoFn.FinishBundleContext}. 
*/ @@ -538,6 +545,12 @@ public void output(OutputT output, Instant timestamp, BoundedWindow window) { public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { outputWindowedValue(tag, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); } + + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + "Bundle finalization is not supported in non-portable pipelines."); + } } /** @@ -791,6 +804,12 @@ public TimerMap timerFamily(String timerFamilyId) { throw new RuntimeException(e); } } + + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + "Bundle finalization is not supported in non-portable pipelines."); + } } /** @@ -1014,6 +1033,12 @@ public void output(TupleTag tag, T output) { public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { outputWindowedValue(tag, WindowedValue.of(output, timestamp, window(), PaneInfo.NO_FIRING)); } + + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + "Bundle finalization is not supported in non-portable pipelines."); + } } private class TimerInternalsTimer implements Timer { diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java index 735e864e8e481..dc795a55dafef 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java @@ -36,6 +36,9 @@ import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext; +import org.apache.beam.sdk.transforms.DoFn.ProcessContext; +import 
org.apache.beam.sdk.transforms.DoFn.StartBundleContext; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; @@ -414,39 +417,62 @@ public String getErrorContext() { TimerInternals.TimerData.of(stateNamespace, wakeupTime, TimeDomain.PROCESSING_TIME)); } - private DoFn.StartBundleContext wrapContextAsStartBundle( + private DoFnInvoker.ArgumentProvider wrapContextAsStartBundle( final StartBundleContext baseContext) { - return fn.new StartBundleContext() { + return new BaseArgumentProvider() { @Override - public PipelineOptions getPipelineOptions() { - return baseContext.getPipelineOptions(); + public DoFn.StartBundleContext startBundleContext( + DoFn doFn) { + return fn.new StartBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return baseContext.getPipelineOptions(); + } + }; } - }; - } - private DoFn.FinishBundleContext wrapContextAsFinishBundle( - final FinishBundleContext baseContext) { - return fn.new FinishBundleContext() { @Override - public void output(OutputT output, Instant timestamp, BoundedWindow window) { - throwUnsupportedOutput(); + public String getErrorContext() { + return "SplittableParDoViaKeyedWorkItems/StartBundle"; } + }; + } + private DoFnInvoker.ArgumentProvider wrapContextAsFinishBundle( + final FinishBundleContext baseContext) { + return new BaseArgumentProvider() { @Override - public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { - throwUnsupportedOutput(); + public DoFn.FinishBundleContext finishBundleContext( + DoFn doFn) { + return fn.new FinishBundleContext() { + @Override + public void output(OutputT output, Instant timestamp, BoundedWindow window) { + throwUnsupportedOutput(); + } + + @Override + public void output( + TupleTag tag, T output, Instant timestamp, BoundedWindow window) { + throwUnsupportedOutput(); + } + + @Override + public PipelineOptions 
getPipelineOptions() { + return baseContext.getPipelineOptions(); + } + + private void throwUnsupportedOutput() { + throw new UnsupportedOperationException( + String.format( + "Splittable DoFn can only output from @%s", + ProcessElement.class.getSimpleName())); + } + }; } @Override - public PipelineOptions getPipelineOptions() { - return baseContext.getPipelineOptions(); - } - - private void throwUnsupportedOutput() { - throw new UnsupportedOperationException( - String.format( - "Splittable DoFn can only output from @%s", - ProcessElement.class.getSimpleName())); + public String getErrorContext() { + return "SplittableParDoViaKeyedWorkItems/FinishBundle"; } }; } diff --git a/runners/direct-java/build.gradle b/runners/direct-java/build.gradle index 6d4652872e427..98dca8e8efe9d 100644 --- a/runners/direct-java/build.gradle +++ b/runners/direct-java/build.gradle @@ -110,6 +110,7 @@ task needsRunnerTests(type: Test) { // MetricsPusher isn't implemented in direct runner excludeCategories "org.apache.beam.sdk.testing.UsesMetricsPusher" excludeCategories "org.apache.beam.sdk.testing.UsesCrossLanguageTransforms" + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } } @@ -134,6 +135,7 @@ task validatesRunner(type: Test) { excludeCategories "org.apache.beam.sdk.testing.LargeKeys\$Above100MB" excludeCategories 'org.apache.beam.sdk.testing.UsesMetricsPusher' excludeCategories "org.apache.beam.sdk.testing.UsesCrossLanguageTransforms" + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } } diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle index ce8cf191af8d4..54eb4c54e133b 100644 --- a/runners/flink/flink_runner.gradle +++ b/runners/flink/flink_runner.gradle @@ -202,6 +202,7 @@ def createValidatesRunnerTask(Map m) { excludeCategories 'org.apache.beam.sdk.testing.LargeKeys$Above100MB' excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics' excludeCategories 
'org.apache.beam.sdk.testing.UsesSystemMetrics' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' if (config.streaming) { excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap' excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse' diff --git a/runners/flink/job-server/flink_job_server.gradle b/runners/flink/job-server/flink_job_server.gradle index 168860d38c4ff..691163902f3c0 100644 --- a/runners/flink/job-server/flink_job_server.gradle +++ b/runners/flink/job-server/flink_job_server.gradle @@ -149,6 +149,7 @@ def portableValidatesRunnerTask(String name, Boolean streaming) { excludeCategories 'org.apache.beam.sdk.testing.UsesSetState' excludeCategories 'org.apache.beam.sdk.testing.UsesStrictTimerOrdering' excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' if (streaming) { excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime' excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages' diff --git a/runners/gearpump/build.gradle b/runners/gearpump/build.gradle index f9bb2594f7695..57d574dec25df 100644 --- a/runners/gearpump/build.gradle +++ b/runners/gearpump/build.gradle @@ -85,6 +85,7 @@ task validatesRunnerStreaming(type: Test) { excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle' excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse' excludeCategories 'org.apache.beam.sdk.testing.UsesMetricsPusher' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } } diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle index c4ccd645e3303..320ad0ca5ce08 100644 --- a/runners/google-cloud-dataflow-java/build.gradle +++ b/runners/google-cloud-dataflow-java/build.gradle @@ -141,6 +141,7 @@ def commonExcludeCategories = [ 'org.apache.beam.sdk.testing.UsesTestStream', 'org.apache.beam.sdk.testing.UsesParDoLifecycle', 
'org.apache.beam.sdk.testing.UsesMetricsPusher', + 'org.apache.beam.sdk.testing.UsesBundleFinalizer' ] def fnApiWorkerExcludeCategories = [ diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java index e910e0116848b..2f99372a2b6ca 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java @@ -53,6 +53,7 @@ import org.apache.beam.sdk.transforms.ParDo.SingleOutput; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; @@ -267,6 +268,25 @@ public boolean isRequiresTimeSortedInput() { return signature.processElement().requiresTimeSortedInput(); } + @Override + public boolean requestsFinalization() { + return (signature.startBundle() != null + && signature + .startBundle() + .extraParameters() + .contains(Parameter.bundleFinalizer())) + || (signature.processElement() != null + && signature + .processElement() + .extraParameters() + .contains(Parameter.bundleFinalizer())) + || (signature.finishBundle() != null + && signature + .finishBundle() + .extraParameters() + .contains(Parameter.bundleFinalizer())); + } + @Override public String translateRestrictionCoderId(SdkComponents newComponents) { if (signature.processElement().isSplittable()) { diff --git a/runners/jet/build.gradle b/runners/jet/build.gradle index f38835c0df4e0..491f3ceb73782 100644 --- a/runners/jet/build.gradle +++ b/runners/jet/build.gradle 
@@ -74,6 +74,7 @@ task validatesRunnerBatch(type: Test) { excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap' excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse' //impulse doesn't cooperate properly with Jet when multiple cluster members are used exclude '**/SplittableDoFnTest.class' //Splittable DoFn functionality not yet in the runner + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } maxHeapSize = '4g' @@ -97,6 +98,7 @@ task needsRunnerTests(type: Test) { useJUnit { includeCategories "org.apache.beam.sdk.testing.NeedsRunner" excludeCategories "org.apache.beam.sdk.testing.LargeKeys\$Above100MB" + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } } diff --git a/runners/portability/java/build.gradle b/runners/portability/java/build.gradle index 9c529a60511bc..b241ec0024606 100644 --- a/runners/portability/java/build.gradle +++ b/runners/portability/java/build.gradle @@ -23,11 +23,6 @@ description = "Apache Beam :: Runners :: Portability :: Java" ext.summary = """A Java implementation of the Beam Model which utilizes the portability framework to execute user-definied functions.""" - -configurations { - validatesRunner -} - dependencies { compile library.java.vendored_guava_26_0_jre compile library.java.hamcrest_library diff --git a/runners/samza/build.gradle b/runners/samza/build.gradle index a041eabec9ddd..ada72bde85dd3 100644 --- a/runners/samza/build.gradle +++ b/runners/samza/build.gradle @@ -87,6 +87,7 @@ task validatesRunner(type: Test) { excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle' excludeCategories 'org.apache.beam.sdk.testing.UsesStrictTimerOrdering' excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } } diff --git a/runners/spark/build.gradle b/runners/spark/build.gradle index d5249d51dc4a1..2325e2b238d4c 100644 --- a/runners/spark/build.gradle +++ b/runners/spark/build.gradle @@ -148,6 +148,7 @@ 
task validatesRunnerBatch(type: Test) { // Portability excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse' excludeCategories 'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } jvmArgs '-Xmx3g' } @@ -214,6 +215,7 @@ task validatesStructuredStreamingRunnerBatch(type: Test) { excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse' excludeCategories 'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms' excludeCategories 'org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } filter { // Combine with context not implemented diff --git a/runners/spark/job-server/build.gradle b/runners/spark/job-server/build.gradle index 0ff64883ec573..f65b95fff15d4 100644 --- a/runners/spark/job-server/build.gradle +++ b/runners/spark/job-server/build.gradle @@ -113,6 +113,7 @@ def portableValidatesRunnerTask(String name) { excludeCategories 'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs' excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo' excludeCategories 'org.apache.beam.sdk.testing.UsesStrictTimerOrdering' + excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' }, ) } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesBundleFinalizer.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesBundleFinalizer.java new file mode 100644 index 0000000000000..5cd2c5b876ee6 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesBundleFinalizer.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.testing; + +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; + +/** + * Category tag for validation tests which use {@link BundleFinalizer}. Tests tagged with {@link + * UsesBundleFinalizer} should be run for runners which support bundle finalization. + */ +public interface UsesBundleFinalizer {} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java index 34937fb1af383..c43840f2d972d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java @@ -599,7 +599,6 @@ public interface MultiOutputReceiver { * Finalize Bundles for further details. * */ - // TODO: Add support for bundle finalization parameter. @Documented @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.METHOD) @@ -1165,7 +1164,6 @@ public void populateDisplayData(DisplayData.Builder builder) {} * consumers without waiting for finalization to succeed. For pipelines that are sensitive to * duplicate messages, they must perform output deduplication in the pipeline. */ - // TODO: Add support for a deduplication PTransform. 
@Experimental(Kind.SPLITTABLE_DO_FN) public interface BundleFinalizer { /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java index fb66a172c333c..ba6bc17a11b86 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java @@ -31,15 +31,20 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.state.TimerMap; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; +import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext; import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; import org.apache.beam.sdk.transforms.DoFn.OnTimerContext; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; +import org.apache.beam.sdk.transforms.DoFn.StartBundleContext; import org.apache.beam.sdk.transforms.Materializations.MultimapView; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker.BaseArgumentProvider; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; @@ -325,6 +330,12 @@ public Timer timer(String timerId) { public TimerMap timerFamily(String tagId) { throw new UnsupportedOperationException("DoFnTester doesn't support timerFamily yet"); } + + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + "DoFnTester doesn't support bundleFinalizer yet"); + } }); } catch (UserCodeException e) { 
unwrapUserCodeException(e); @@ -459,38 +470,48 @@ public TupleTag getMainOutputTag() { return mainOutputTag; } - private class TestStartBundleContext extends DoFn.StartBundleContext { - - private TestStartBundleContext() { - fn.super(); + private class TestStartBundleContext extends BaseArgumentProvider { + @Override + public StartBundleContext startBundleContext(DoFn doFn) { + return fn.new StartBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return options; + } + }; } @Override - public PipelineOptions getPipelineOptions() { - return options; + public String getErrorContext() { + return "DoFnTester/StartBundle"; } } - private class TestFinishBundleContext extends DoFn.FinishBundleContext { - - private TestFinishBundleContext() { - fn.super(); - } - + private class TestFinishBundleContext extends BaseArgumentProvider { @Override - public PipelineOptions getPipelineOptions() { - return options; - } + public FinishBundleContext finishBundleContext(DoFn doFn) { + return fn.new FinishBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return options; + } - @Override - public void output(OutputT output, Instant timestamp, BoundedWindow window) { - output(mainOutputTag, output, timestamp, window); + @Override + public void output(OutputT output, Instant timestamp, BoundedWindow window) { + output(mainOutputTag, output, timestamp, window); + } + + @Override + public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { + getMutableOutput(tag) + .add(ValueInSingleWindow.of(output, timestamp, window, PaneInfo.NO_FIRING)); + } + }; } @Override - public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { - getMutableOutput(tag) - .add(ValueInSingleWindow.of(output, timestamp, window, PaneInfo.NO_FIRING)); + public String getErrorContext() { + return "DoFnTester/FinishBundle"; } } diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java index 5649ff6e1f2bb..503942c0e9c19 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java @@ -36,6 +36,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFn.ProcessElement; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.OnTimerMethod; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BundleFinalizerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.Cases; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ElementParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.FinishBundleContextParameter; @@ -107,6 +108,7 @@ class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory { public static final String ELEMENT_PARAMETER_METHOD = "element"; public static final String SCHEMA_ELEMENT_PARAMETER_METHOD = "schemaElement"; public static final String TIMESTAMP_PARAMETER_METHOD = "timestamp"; + public static final String BUNDLE_FINALIZER_PARAMETER_METHOD = "bundleFinalizer"; public static final String OUTPUT_ROW_RECEIVER_METHOD = "outputRowReceiver"; public static final String TIME_DOMAIN_PARAMETER_METHOD = "timeDomain"; public static final String OUTPUT_PARAMETER_METHOD = "outputReceiver"; @@ -366,9 +368,11 @@ public static double invokeGetSize( // public invokeStartBundle(Context c) { delegate.<@StartBundle>(c); } // ... etc ... 
.method(ElementMatchers.named("invokeStartBundle")) - .intercept(delegateOrNoop(clazzDescription, signature.startBundle())) + .intercept( + delegateMethodWithExtraParametersOrNoop(clazzDescription, signature.startBundle())) .method(ElementMatchers.named("invokeFinishBundle")) - .intercept(delegateOrNoop(clazzDescription, signature.finishBundle())) + .intercept( + delegateMethodWithExtraParametersOrNoop(clazzDescription, signature.finishBundle())) .method(ElementMatchers.named("invokeSetup")) .intercept(delegateOrNoop(clazzDescription, signature.setup())) .method(ElementMatchers.named("invokeTeardown")) @@ -779,6 +783,15 @@ public StackManipulation dispatch(TimestampParameter p) { TIMESTAMP_PARAMETER_METHOD, DoFn.class))); } + @Override + public StackManipulation dispatch(BundleFinalizerParameter p) { + return new StackManipulation.Compound( + pushDelegate, + MethodInvocation.invoke( + getExtraContextFactoryMethodDescription( + BUNDLE_FINALIZER_PARAMETER_METHOD, DoFn.class))); + } + @Override public StackManipulation dispatch(TimeDomainParameter p) { return new StackManipulation.Compound( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java index e4ac2d8ccd849..d52872e53714e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.state.TimerMap; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.DoFn.FinishBundle; import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; @@ -53,10 +54,10 @@ public interface DoFnInvoker { void invokeSetup(); /** Invoke the {@link 
DoFn.StartBundle} method on the bound {@link DoFn}. */ - void invokeStartBundle(DoFn.StartBundleContext c); + void invokeStartBundle(ArgumentProvider arguments); /** Invoke the {@link DoFn.FinishBundle} method on the bound {@link DoFn}. */ - void invokeFinishBundle(DoFn.FinishBundleContext c); + void invokeFinishBundle(ArgumentProvider arguments); /** Invoke the {@link DoFn.Teardown} method on the bound {@link DoFn}. */ void invokeTeardown(); @@ -168,9 +169,15 @@ interface ArgumentProvider { /** Provide a {@link OutputReceiver} for outputting rows to the default output. */ OutputReceiver outputRowReceiver(DoFn doFn); - /** Provide a {@link MultiOutputReceiver} for outputing to the default output. */ + /** Provide a {@link MultiOutputReceiver} for outputting to the default output. */ MultiOutputReceiver taggedOutputReceiver(DoFn doFn); + /** + * Provide a {@link BundleFinalizer} for being able to register a callback after the bundle has + * been successfully persisted by the runner. + */ + BundleFinalizer bundleFinalizer(); + /** * If this is a splittable {@link DoFn}, returns the associated restriction with the current * call. @@ -330,6 +337,12 @@ public Timer timer(String timerId) { String.format("RestrictionTracker unsupported in %s", getErrorContext())); } + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException( + String.format("BundleFinalizer unsupported in %s", getErrorContext())); + } + /** * Return a human readable representation of the current call context to be used during error * reporting. 
@@ -455,6 +468,11 @@ public String timerId(DoFn doFn) { return delegate.timerId(doFn); } + @Override + public BundleFinalizer bundleFinalizer() { + return delegate.bundleFinalizer(); + } + @Override public String getErrorContext() { return errorContext; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index b472cc31758fc..b675596c14350 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -271,6 +271,8 @@ public ResultT match(Cases cases) { return cases.dispatch((TimerFamilyParameter) this); } else if (this instanceof TimerIdParameter) { return cases.dispatch((TimerIdParameter) this); + } else if (this instanceof BundleFinalizerParameter) { + return cases.dispatch((BundleFinalizerParameter) this); } else { throw new IllegalStateException( String.format( @@ -321,6 +323,8 @@ public interface Cases { ResultT dispatch(TimerIdParameter p); + ResultT dispatch(BundleFinalizerParameter p); + /** A base class for a visitor with a default method for cases it is not interested in. 
*/ abstract class WithDefault implements Cases { @@ -401,6 +405,11 @@ public ResultT dispatch(RestrictionTrackerParameter p) { return dispatchDefault(p); } + @Override + public ResultT dispatch(BundleFinalizerParameter p) { + return dispatchDefault(p); + } + @Override public ResultT dispatch(StateParameter p) { return dispatchDefault(p); @@ -449,12 +458,29 @@ public ResultT dispatch(TimerFamilyParameter p) { new AutoValue_DoFnSignature_Parameter_TaggedOutputReceiverParameter(); private static final PipelineOptionsParameter PIPELINE_OPTIONS_PARAMETER = new AutoValue_DoFnSignature_Parameter_PipelineOptionsParameter(); + private static final BundleFinalizerParameter BUNDLE_FINALIZER_PARAMETER = + new AutoValue_DoFnSignature_Parameter_BundleFinalizerParameter(); /** Returns a {@link ProcessContextParameter}. */ public static ProcessContextParameter processContext() { return PROCESS_CONTEXT_PARAMETER; } + /** Returns a {@link StartBundleContextParameter}. */ + public static StartBundleContextParameter startBundleContext() { + return START_BUNDLE_CONTEXT_PARAMETER; + } + + /** Returns a {@link FinishBundleContextParameter}. */ + public static FinishBundleContextParameter finishBundleContext() { + return FINISH_BUNDLE_CONTEXT_PARAMETER; + } + + /** Returns a {@link BundleFinalizerParameter}. */ + public static BundleFinalizerParameter bundleFinalizer() { + return BUNDLE_FINALIZER_PARAMETER; + } + public static ElementParameter elementParameter(TypeDescriptor elementT) { return new AutoValue_DoFnSignature_Parameter_ElementParameter(elementT); } @@ -574,6 +600,16 @@ public abstract static class ProcessContextParameter extends Parameter { ProcessContextParameter() {} } + /** + * Descriptor for a {@link Parameter} of type {@link DoFn.BundleFinalizer}. + * + *

All such descriptors are equal. + */ + @AutoValue + public abstract static class BundleFinalizerParameter extends Parameter { + BundleFinalizerParameter() {} + } + /** * Descriptor for a {@link Parameter} of type {@link DoFn.Element}. * @@ -1041,13 +1077,23 @@ static TimerFamilyDeclaration create(String id, Field field) { /** Describes a {@link DoFn.StartBundle} or {@link DoFn.FinishBundle} method. */ @AutoValue - public abstract static class BundleMethod implements DoFnMethod { + public abstract static class BundleMethod implements MethodWithExtraParameters { /** The annotated method itself. */ @Override public abstract Method targetMethod(); - static BundleMethod create(Method targetMethod) { - return new AutoValue_DoFnSignature_BundleMethod(targetMethod); + /** Types of optional parameters of the annotated method, in the order they appear. */ + @Override + public abstract List extraParameters(); + + /** The type of window expected by this method, if any. */ + @Override + @Nullable + public abstract TypeDescriptor windowT(); + + static BundleMethod create(Method targetMethod, List extraParameters) { + /* start bundle/finish bundle currently do not get invoked on a per window basis and can't accept a BoundedWindow parameter */ + return new AutoValue_DoFnSignature_BundleMethod(targetMethod, extraParameters, null); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index 007663c587d5e..d34baba15ee5b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -112,7 +112,8 @@ private DoFnSignatures() {} Parameter.TimerParameter.class, Parameter.StateParameter.class, Parameter.SideInputParameter.class, - Parameter.TimerFamilyParameter.class); + Parameter.TimerFamilyParameter.class, + 
Parameter.BundleFinalizerParameter.class); private static final ImmutableList> ALLOWED_SPLITTABLE_PROCESS_ELEMENT_PARAMETERS = @@ -126,7 +127,16 @@ private DoFnSignatures() {} Parameter.TaggedOutputReceiverParameter.class, Parameter.ProcessContextParameter.class, Parameter.RestrictionTrackerParameter.class, - Parameter.SideInputParameter.class); + Parameter.SideInputParameter.class, + Parameter.BundleFinalizerParameter.class); + + private static final ImmutableList> ALLOWED_START_BUNDLE_PARAMETERS = + ImmutableList.of( + Parameter.StartBundleContextParameter.class, Parameter.BundleFinalizerParameter.class); + + private static final ImmutableList> ALLOWED_FINISH_BUNDLE_PARAMETERS = + ImmutableList.of( + Parameter.FinishBundleContextParameter.class, Parameter.BundleFinalizerParameter.class); private static final ImmutableList> ALLOWED_ON_TIMER_PARAMETERS = ImmutableList.of( @@ -550,14 +560,16 @@ private static DoFnSignature parseSignature(Class> fnClass) if (startBundleMethod != null) { ErrorReporter startBundleErrors = errors.forMethod(DoFn.StartBundle.class, startBundleMethod); signatureBuilder.setStartBundle( - analyzeStartBundleMethod(startBundleErrors, fnT, startBundleMethod, inputT, outputT)); + analyzeStartBundleMethod( + startBundleErrors, fnT, startBundleMethod, inputT, outputT, fnContext)); } if (finishBundleMethod != null) { ErrorReporter finishBundleErrors = errors.forMethod(DoFn.FinishBundle.class, finishBundleMethod); signatureBuilder.setFinishBundle( - analyzeFinishBundleMethod(finishBundleErrors, fnT, finishBundleMethod, inputT, outputT)); + analyzeFinishBundleMethod( + finishBundleErrors, fnT, finishBundleMethod, inputT, outputT, fnContext)); } if (setupMethod != null) { @@ -1085,6 +1097,8 @@ private static Parameter analyzeExtraParameter( TypeDescriptor outputT) { TypeDescriptor expectedProcessContextT = doFnProcessContextTypeOf(inputT, outputT); + TypeDescriptor expectedStartBundleContextT = doFnStartBundleContextTypeOf(inputT, outputT); + 
TypeDescriptor expectedFinishBundleContextT = doFnFinishBundleContextTypeOf(inputT, outputT); TypeDescriptor expectedOnTimerContextT = doFnOnTimerContextTypeOf(inputT, outputT); TypeDescriptor paramT = param.getType(); @@ -1115,12 +1129,26 @@ private static Parameter analyzeExtraParameter( return Parameter.sideInputParameter(paramT, sideInputId); } else if (rawType.equals(PaneInfo.class)) { return Parameter.paneInfoParameter(); + } else if (rawType.equals(DoFn.BundleFinalizer.class)) { + return Parameter.bundleFinalizer(); } else if (rawType.equals(DoFn.ProcessContext.class)) { paramErrors.checkArgument( paramT.equals(expectedProcessContextT), "ProcessContext argument must have type %s", format(expectedProcessContextT)); return Parameter.processContext(); + } else if (rawType.equals(DoFn.StartBundleContext.class)) { + paramErrors.checkArgument( + paramT.equals(expectedStartBundleContextT), + "StartBundleContext argument must have type %s", + format(expectedStartBundleContextT)); + return Parameter.startBundleContext(); + } else if (rawType.equals(DoFn.FinishBundleContext.class)) { + paramErrors.checkArgument( + paramT.equals(expectedFinishBundleContextT), + "FinishBundleContext argument must have type %s", + format(expectedFinishBundleContextT)); + return Parameter.finishBundleContext(); } else if (rawType.equals(DoFn.OnTimerContext.class)) { paramErrors.checkArgument( paramT.equals(expectedOnTimerContextT), @@ -1367,16 +1395,30 @@ static DoFnSignature.BundleMethod analyzeStartBundleMethod( TypeDescriptor> fnT, Method m, TypeDescriptor inputT, - TypeDescriptor outputT) { + TypeDescriptor outputT, + FnAnalysisContext fnContext) { errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); - TypeDescriptor expectedContextT = doFnStartBundleContextTypeOf(inputT, outputT); Type[] params = m.getGenericParameterTypes(); - errors.checkArgument( - params.length == 0 - || (params.length == 1 && fnT.resolveType(params[0]).equals(expectedContextT)), - "Must take a
single argument of type %s", - format(expectedContextT)); - return DoFnSignature.BundleMethod.create(m); + MethodAnalysisContext methodContext = MethodAnalysisContext.create(); + for (int i = 0; i < params.length; ++i) { + Parameter extraParam = + analyzeExtraParameter( + errors, + fnContext, + methodContext, + fnT, + ParameterDescription.of( + m, i, fnT.resolveType(params[i]), Arrays.asList(m.getParameterAnnotations()[i])), + inputT, + outputT); + methodContext.addParameter(extraParam); + } + + for (Parameter parameter : methodContext.getExtraParameters()) { + checkParameterOneOf(errors, parameter, ALLOWED_START_BUNDLE_PARAMETERS); + } + + return DoFnSignature.BundleMethod.create(m, methodContext.extraParameters); } @VisibleForTesting @@ -1385,16 +1427,30 @@ static DoFnSignature.BundleMethod analyzeFinishBundleMethod( TypeDescriptor> fnT, Method m, TypeDescriptor inputT, - TypeDescriptor outputT) { + TypeDescriptor outputT, + FnAnalysisContext fnContext) { errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); - TypeDescriptor expectedContextT = doFnFinishBundleContextTypeOf(inputT, outputT); Type[] params = m.getGenericParameterTypes(); - errors.checkArgument( - params.length == 0 - || (params.length == 1 && fnT.resolveType(params[0]).equals(expectedContextT)), - "Must take a single argument of type %s", - format(expectedContextT)); - return DoFnSignature.BundleMethod.create(m); + MethodAnalysisContext methodContext = MethodAnalysisContext.create(); + for (int i = 0; i < params.length; ++i) { + Parameter extraParam = + analyzeExtraParameter( + errors, + fnContext, + methodContext, + fnT, + ParameterDescription.of( + m, i, fnT.resolveType(params[i]), Arrays.asList(m.getParameterAnnotations()[i])), + inputT, + outputT); + methodContext.addParameter(extraParam); + } + + for (Parameter parameter : methodContext.getExtraParameters()) { + checkParameterOneOf(errors, parameter, ALLOWED_FINISH_BUNDLE_PARAMETERS); + } + + return 
DoFnSignature.BundleMethod.create(m, methodContext.extraParameters); } private static DoFnSignature.LifecycleMethod analyzeLifecycleMethod( diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java index e22f72920deba..d9c7742bc01fd 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.transforms; +import static java.lang.Thread.sleep; import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume; import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; @@ -28,6 +29,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.KvCoder; @@ -39,6 +41,7 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.testing.UsesBoundedSplittableParDo; +import org.apache.beam.sdk.testing.UsesBundleFinalizer; import org.apache.beam.sdk.testing.UsesParDoLifecycle; import org.apache.beam.sdk.testing.UsesSideInputs; import org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs; @@ -47,7 +50,9 @@ import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.DoFn.BoundedPerElement; import org.apache.beam.sdk.transforms.DoFn.UnboundedPerElement; +import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; +import 
org.apache.beam.sdk.transforms.splittabledofn.SplitResult; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; import org.apache.beam.sdk.transforms.windowing.Never; @@ -73,7 +78,7 @@ import org.junit.runners.JUnit4; /** - * Tests for splittable {@link DoFn} behavior. */ @RunWith(JUnit4.class) public class SplittableDoFnTest implements Serializable { @@ -807,6 +812,82 @@ public OffsetRange getInitialRestriction() { } } + /** + * While the finalization callback hasn't been invoked, this DoFn will keep requesting + * finalization, wait one second and then checkpoint up to MAX_ATTEMPTS times. Once the + * callback has been invoked, the DoFn will output the element and stop. + */ + public static class BundleFinalizingSplittableDoFn extends DoFn { + private static final long MAX_ATTEMPTS = 300; + private static final AtomicBoolean wasFinalized = new AtomicBoolean(); + + @NewTracker + public RestrictionTracker newTracker(@Restriction OffsetRange restriction) { + // Use a modified OffsetRangeTracker with disabled splitting to prevent + // parallelization of execution. + return new OffsetRangeTracker(restriction) { + @Override + public SplitResult trySplit(double fractionOfRemainder) { + return null; + } + }; + } + + @ProcessElement + public ProcessContinuation process( + @Element String element, + OutputReceiver receiver, + RestrictionTracker tracker, + BundleFinalizer bundleFinalizer) + throws InterruptedException { + if (wasFinalized.get()) { + // Claim beyond the end now that we know we have been finalized.
+ tracker.tryClaim(Long.MAX_VALUE); + receiver.output(element); + return stop(); + } + if (tracker.tryClaim(tracker.currentRestriction().getFrom() + 1)) { + bundleFinalizer.afterBundleCommit( + Instant.now().plus(Duration.standardSeconds(MAX_ATTEMPTS)), + () -> wasFinalized.set(true)); + // We sleep here instead of setting a resume time since the resume time doesn't need to + // be honored. + sleep(1000L); // 1 second + return resume(); + } + return stop(); + } + + @GetInitialRestriction + public OffsetRange getInitialRestriction() { + return new OffsetRange(0, MAX_ATTEMPTS); + } + } + + @Test + @Category({ValidatesRunner.class, UsesBoundedSplittableParDo.class, UsesBundleFinalizer.class}) + public void testBundleFinalizationOccursOnBoundedSplittableDoFn() throws Exception { + @BoundedPerElement + class BoundedBundleFinalizingSplittableDoFn extends BundleFinalizingSplittableDoFn {} + Pipeline p = TestPipeline.create(); + PCollection foo = p.apply(Create.of("foo")); + PCollection res = foo.apply(ParDo.of(new BoundedBundleFinalizingSplittableDoFn())); + PAssert.that(res).containsInAnyOrder("foo"); + p.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesUnboundedSplittableParDo.class, UsesBundleFinalizer.class}) + public void testBundleFinalizationOccursOnUnboundedSplittableDoFn() throws Exception { + @UnboundedPerElement + class UnboundedBundleFinalizingSplittableDoFn extends BundleFinalizingSplittableDoFn {} + Pipeline p = TestPipeline.create(); + PCollection foo = p.apply(Create.of("foo")); + PCollection res = foo.apply(ParDo.of(new UnboundedBundleFinalizingSplittableDoFn())); + PAssert.that(res).containsInAnyOrder("foo"); + p.run(); + } + // TODO (https://issues.apache.org/jira/browse/BEAM-988): Test that Splittable DoFn // emits output immediately (i.e. has a pass-through trigger) regardless of input's // windowing/triggering strategy.
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java index 8de6785af950b..902d982b99af6 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java @@ -26,6 +26,7 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; @@ -70,7 +71,6 @@ import org.mockito.AdditionalAnswers; import org.mockito.Matchers; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; /** Tests for {@link DoFnInvokers}. */ @@ -342,6 +342,10 @@ public SomeRestrictionTracker newTracker(@Restriction SomeRestriction restrictio @Test public void testDoFnWithStartBundleSetupTeardown() throws Exception { + when(mockArgumentProvider.startBundleContext(any(DoFn.class))) + .thenReturn(mockStartBundleContext); + when(mockArgumentProvider.finishBundleContext(any(DoFn.class))) + .thenReturn(mockFinishBundleContext); class MockFn extends DoFn { @ProcessElement public void processElement(ProcessContext c) {} @@ -362,8 +366,8 @@ public void after() {} MockFn fn = mock(MockFn.class); DoFnInvoker invoker = DoFnInvokers.invokerFor(fn); invoker.invokeSetup(); - invoker.invokeStartBundle(mockStartBundleContext); - invoker.invokeFinishBundle(mockFinishBundleContext); + invoker.invokeStartBundle(mockArgumentProvider); + invoker.invokeFinishBundle(mockArgumentProvider); invoker.invokeTeardown(); verify(fn).before(); verify(fn).startBundle(mockStartBundleContext); @@ -450,7 +454,7 @@ public void splitRestriction( } })) .when(fn) - .splitRestriction(eq(mockElement), 
same(restriction), Mockito.any()); + .splitRestriction(eq(mockElement), same(restriction), any()); when(fn.newTracker(restriction)).thenReturn(tracker); when(fn.processElement(mockProcessContext, tracker)).thenReturn(resume()); @@ -832,6 +836,9 @@ public DoFn.ProcessContext processContext(DoFn doFn) { @Test public void testStartBundleException() throws Exception { + DoFnInvoker.ArgumentProvider mockArguments = + mock(DoFnInvoker.ArgumentProvider.class); + when(mockArguments.startBundleContext(any(DoFn.class))).thenReturn(null); DoFnInvoker invoker = DoFnInvokers.invokerFor( new DoFn() { @@ -845,11 +852,14 @@ public void processElement(@SuppressWarnings("unused") ProcessContext c) {} }); thrown.expect(UserCodeException.class); thrown.expectMessage("bogus"); - invoker.invokeStartBundle(null); + invoker.invokeStartBundle(mockArguments); } @Test public void testFinishBundleException() throws Exception { + DoFnInvoker.ArgumentProvider mockArguments = + mock(DoFnInvoker.ArgumentProvider.class); + when(mockArguments.finishBundleContext(any(DoFn.class))).thenReturn(null); DoFnInvoker invoker = DoFnInvokers.invokerFor( new DoFn() { @@ -863,7 +873,7 @@ public void processElement(@SuppressWarnings("unused") ProcessContext c) {} }); thrown.expect(UserCodeException.class); thrown.expectMessage("bogus"); - invoker.invokeFinishBundle(null); + invoker.invokeFinishBundle(mockArguments); } @Test diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java index dc0716d5f951d..24e581b6893dd 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java @@ -55,19 +55,23 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Sum; import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BundleFinalizerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ElementParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.FinishBundleContextParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.OutputReceiverParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.PaneInfoParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.PipelineOptionsParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ProcessContextParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SideInputParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StartBundleContextParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TaggedOutputReceiverParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimeDomainParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimestampParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures.FnAnalysisContext; import org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.FakeDoFn; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -121,10 +125,11 @@ public void process( OutputReceiver receiver, PipelineOptions options, @SideInput("tag1") String input1, - @SideInput("tag2") 
Integer input2) {} + @SideInput("tag2") Integer input2, + BundleFinalizer bundleFinalizer) {} }.getClass()); - assertThat(sig.processElement().extraParameters().size(), equalTo(8)); + assertThat(sig.processElement().extraParameters().size(), equalTo(9)); assertThat(sig.processElement().extraParameters().get(0), instanceOf(ElementParameter.class)); assertThat(sig.processElement().extraParameters().get(1), instanceOf(TimestampParameter.class)); assertThat(sig.processElement().extraParameters().get(2), instanceOf(WindowParameter.class)); @@ -135,6 +140,8 @@ public void process( sig.processElement().extraParameters().get(5), instanceOf(PipelineOptionsParameter.class)); assertThat(sig.processElement().extraParameters().get(6), instanceOf(SideInputParameter.class)); assertThat(sig.processElement().extraParameters().get(7), instanceOf(SideInputParameter.class)); + assertThat( + sig.processElement().extraParameters().get(8), instanceOf(BundleFinalizerParameter.class)); } @Test @@ -276,8 +283,7 @@ public void process(ProcessContext c) {} @Test public void testBadExtraContext() throws Exception { thrown.expect(IllegalArgumentException.class); - thrown.expectMessage( - "Must take a single argument of type DoFn.StartBundleContext"); + thrown.expectMessage("int is not a valid context parameter"); DoFnSignatures.analyzeStartBundleMethod( errors(), @@ -286,7 +292,8 @@ public void testBadExtraContext() throws Exception { void method(DoFn.StartBundleContext c, int n) {} }.getMethod(), TypeDescriptor.of(Integer.class), - TypeDescriptor.of(String.class)); + TypeDescriptor.of(String.class), + FnAnalysisContext.create()); } @Test @@ -345,6 +352,25 @@ void startBundle() {} }.getClass()); } + @Test + public void testStartBundleWithAllParameters() throws Exception { + DoFnSignature sig = + DoFnSignatures.getSignature( + new DoFn() { + @ProcessElement + public void processElement() {} + + @StartBundle + public void startBundle( + StartBundleContext context, BundleFinalizer bundleFinalizer) 
{} + }.getClass()); + assertThat(sig.startBundle().extraParameters().size(), equalTo(2)); + assertThat( + sig.startBundle().extraParameters().get(0), instanceOf(StartBundleContextParameter.class)); + assertThat( + sig.startBundle().extraParameters().get(1), instanceOf(BundleFinalizerParameter.class)); + } + @Test public void testPrivateFinishBundle() throws Exception { thrown.expect(IllegalArgumentException.class); @@ -361,6 +387,26 @@ void finishBundle() {} }.getClass()); } + @Test + public void testFinishBundleWithAllParameters() throws Exception { + DoFnSignature sig = + DoFnSignatures.getSignature( + new DoFn() { + @ProcessElement + public void processElement() {} + + @FinishBundle + public void finishBundle( + FinishBundleContext context, BundleFinalizer bundleFinalizer) {} + }.getClass()); + assertThat(sig.finishBundle().extraParameters().size(), equalTo(2)); + assertThat( + sig.finishBundle().extraParameters().get(0), + instanceOf(FinishBundleContextParameter.class)); + assertThat( + sig.finishBundle().extraParameters().get(1), instanceOf(BundleFinalizerParameter.class)); + } + @Test public void testTimerIdWithWrongType() throws Exception { thrown.expect(IllegalArgumentException.class); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java index e1756dc5a52dd..650d86b648e24 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java @@ -50,6 +50,7 @@ import org.apache.beam.sdk.fn.data.RemoteGrpcPortRead; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; 
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Ints; @@ -95,7 +96,8 @@ public BeamFnDataReadRunner createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { FnDataReceiver> consumer = diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java index 8cb9d33b7529b..86deadb3286bf 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java @@ -45,6 +45,7 @@ import org.apache.beam.sdk.fn.data.RemoteGrpcPortWrite; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; import org.slf4j.Logger; @@ -89,7 +90,8 @@ public BeamFnDataWriteRunner createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { BeamFnDataWriteRunner runner = diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java index a06c002ac6dbc..4429bb81552c1 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java +++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java @@ -41,6 +41,7 @@ import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.Source.Reader; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException; @@ -83,7 +84,8 @@ public BoundedSourceRunner createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) { + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) { ImmutableList.Builder>> consumers = ImmutableList.builder(); for (String pCollectionId : pTransform.getOutputsMap().values()) { consumers.add(pCollectionConsumerRegistry.getMultiplexingConsumer(pCollectionId)); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java index 6bd16bed86971..feb6b6b76faba 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java @@ -40,6 +40,7 @@ import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; @@ -127,7 +128,8 @@ public PrecombineRunner createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer 
tearDownFunctions, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { // Get objects needed to create the runner. RehydratedComponents rehydratedComponents = diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java index 702751c6bbc6c..2307b5b2e7fbe 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java @@ -36,6 +36,7 @@ import org.apache.beam.sdk.fn.data.FnDataReceiver; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; @@ -68,7 +69,8 @@ public FlattenRunner createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { // Give each input a MultiplexingFnDataReceiver to all outputs of the flatten. 
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index 907ec74b4ff3d..3ca5d4c482094 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -64,6 +64,7 @@ import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.state.TimerMap; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; import org.apache.beam.sdk.transforms.DoFnOutputReceivers; @@ -71,6 +72,7 @@ import org.apache.beam.sdk.transforms.Materializations; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker.BaseArgumentProvider; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker.DelegatingArgumentProvider; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; @@ -127,169 +129,6 @@ public Map getPTransformRunnerFactories() { } } - static class Context { - final PipelineOptions pipelineOptions; - final BeamFnStateClient beamFnStateClient; - final String ptransformId; - final PTransform pTransform; - final Supplier processBundleInstructionId; - final RehydratedComponents rehydratedComponents; - final DoFn doFn; - final DoFnSignature doFnSignature; - final TupleTag mainOutputTag; - final Coder inputCoder; - final SchemaCoder schemaCoder; - final Coder keyCoder; - final SchemaCoder mainOutputSchemaCoder; - final Coder windowCoder; - final WindowingStrategy windowingStrategy; - final Map, SideInputSpec> tagToSideInputSpecMap; - Map, Coder> outputCoders; - final ParDoPayload parDoPayload; 
- final ListMultimap>> localNameToConsumer; - final BundleSplitListener splitListener; - - Context( - PipelineOptions pipelineOptions, - BeamFnStateClient beamFnStateClient, - String ptransformId, - PTransform pTransform, - Supplier processBundleInstructionId, - Map pCollections, - Map coders, - Map windowingStrategies, - PCollectionConsumerRegistry pCollectionConsumerRegistry, - BundleSplitListener splitListener) { - this.pipelineOptions = pipelineOptions; - this.beamFnStateClient = beamFnStateClient; - this.ptransformId = ptransformId; - this.pTransform = pTransform; - this.processBundleInstructionId = processBundleInstructionId; - ImmutableMap.Builder, SideInputSpec> tagToSideInputSpecMapBuilder = - ImmutableMap.builder(); - try { - rehydratedComponents = - RehydratedComponents.forComponents( - RunnerApi.Components.newBuilder() - .putAllCoders(coders) - .putAllPcollections(pCollections) - .putAllWindowingStrategies(windowingStrategies) - .build()) - .withPipeline(Pipeline.create()); - parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload()); - doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload); - doFnSignature = DoFnSignatures.signatureForDoFn(doFn); - switch (pTransform.getSpec().getUrn()) { - case PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN: - case PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN: - case PTransformTranslation.PAR_DO_TRANSFORM_URN: - mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload); - break; - case PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN: - case PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN: - case PTransformTranslation.SPLITTABLE_SPLIT_RESTRICTION_URN: - mainOutputTag = - new TupleTag(Iterables.getOnlyElement(pTransform.getOutputsMap().keySet())); - break; - default: - throw new IllegalStateException( - String.format("Unknown urn: %s", pTransform.getSpec().getUrn())); - } - String mainInputTag = - Iterables.getOnlyElement( - 
Sets.difference( - pTransform.getInputsMap().keySet(), - Sets.union( - parDoPayload.getSideInputsMap().keySet(), - parDoPayload.getTimerSpecsMap().keySet()))); - PCollection mainInput = pCollections.get(pTransform.getInputsOrThrow(mainInputTag)); - inputCoder = rehydratedComponents.getCoder(mainInput.getCoderId()); - if (inputCoder instanceof KvCoder - // TODO: Stop passing windowed value coders within PCollections. - || (inputCoder instanceof WindowedValue.WindowedValueCoder - && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof KvCoder))) { - this.keyCoder = - inputCoder instanceof WindowedValueCoder - ? ((KvCoder) ((WindowedValueCoder) inputCoder).getValueCoder()).getKeyCoder() - : ((KvCoder) inputCoder).getKeyCoder(); - } else { - this.keyCoder = null; - } - if (inputCoder instanceof SchemaCoder - // TODO: Stop passing windowed value coders within PCollections. - || (inputCoder instanceof WindowedValue.WindowedValueCoder - && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof SchemaCoder))) { - this.schemaCoder = - inputCoder instanceof WindowedValueCoder - ? 
(SchemaCoder) ((WindowedValueCoder) inputCoder).getValueCoder() - : ((SchemaCoder) inputCoder); - } else { - this.schemaCoder = null; - } - - windowingStrategy = - (WindowingStrategy) - rehydratedComponents.getWindowingStrategy(mainInput.getWindowingStrategyId()); - windowCoder = windowingStrategy.getWindowFn().windowCoder(); - - outputCoders = Maps.newHashMap(); - for (Map.Entry entry : pTransform.getOutputsMap().entrySet()) { - TupleTag outputTag = new TupleTag<>(entry.getKey()); - RunnerApi.PCollection outputPCollection = pCollections.get(entry.getValue()); - Coder outputCoder = rehydratedComponents.getCoder(outputPCollection.getCoderId()); - if (outputCoder instanceof WindowedValueCoder) { - outputCoder = ((WindowedValueCoder) outputCoder).getValueCoder(); - } - outputCoders.put(outputTag, outputCoder); - } - Coder outputCoder = (Coder) outputCoders.get(mainOutputTag); - mainOutputSchemaCoder = - (outputCoder instanceof SchemaCoder) ? (SchemaCoder) outputCoder : null; - - // Build the map from tag id to side input specification - for (Map.Entry entry : - parDoPayload.getSideInputsMap().entrySet()) { - String sideInputTag = entry.getKey(); - RunnerApi.SideInput sideInput = entry.getValue(); - checkArgument( - Materializations.MULTIMAP_MATERIALIZATION_URN.equals( - sideInput.getAccessPattern().getUrn()), - "This SDK is only capable of dealing with %s materializations " - + "but was asked to handle %s for PCollectionView with tag %s.", - Materializations.MULTIMAP_MATERIALIZATION_URN, - sideInput.getAccessPattern().getUrn(), - sideInputTag); - - PCollection sideInputPCollection = - pCollections.get(pTransform.getInputsOrThrow(sideInputTag)); - WindowingStrategy sideInputWindowingStrategy = - rehydratedComponents.getWindowingStrategy( - sideInputPCollection.getWindowingStrategyId()); - tagToSideInputSpecMapBuilder.put( - new TupleTag<>(entry.getKey()), - SideInputSpec.create( - rehydratedComponents.getCoder(sideInputPCollection.getCoderId()), - 
sideInputWindowingStrategy.getWindowFn().windowCoder(), - PCollectionViewTranslation.viewFnFromProto(entry.getValue().getViewFn()), - PCollectionViewTranslation.windowMappingFnFromProto( - entry.getValue().getWindowMappingFn()))); - } - } catch (IOException exn) { - throw new IllegalArgumentException("Malformed ParDoPayload", exn); - } - - ImmutableListMultimap.Builder>> - localNameToConsumerBuilder = ImmutableListMultimap.builder(); - for (Map.Entry entry : pTransform.getOutputsMap().entrySet()) { - localNameToConsumerBuilder.putAll( - entry.getKey(), pCollectionConsumerRegistry.getMultiplexingConsumer(entry.getValue())); - } - localNameToConsumer = localNameToConsumerBuilder.build(); - tagToSideInputSpecMap = tagToSideInputSpecMapBuilder.build(); - this.splitListener = splitListener; - } - } - static class Factory implements PTransformRunnerFactory< FnApiDoFnRunner> { @@ -310,9 +149,11 @@ static class Factory PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) { - Context context = - new Context<>( + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) { + + FnApiDoFnRunner runner = + new FnApiDoFnRunner<>( pipelineOptions, beamFnStateClient, pTransformId, @@ -322,10 +163,8 @@ static class Factory coders, windowingStrategies, pCollectionConsumerRegistry, - splitListener); - - FnApiDoFnRunner runner = - new FnApiDoFnRunner<>(context); + splitListener, + bundleFinalizer); // Register the appropriate handlers. startFunctionRegistry.register(pTransformId, runner::startBundle); @@ -360,10 +199,10 @@ static class Factory pTransform.getInputsOrThrow(mainInput), pTransformId, (FnDataReceiver) mainInputConsumer); // Register as a consumer for each timer PCollection. 
- for (String localName : context.parDoPayload.getTimerSpecsMap().keySet()) { + for (String localName : runner.parDoPayload.getTimerSpecsMap().keySet()) { TimeDomain timeDomain = DoFnSignatures.getTimerSpecOrThrow( - context.doFnSignature.timerDeclarations().get(localName), context.doFn) + runner.doFnSignature.timerDeclarations().get(localName), runner.doFn) .getTimeDomain(); pCollectionConsumerRegistry.register( pTransform.getInputsOrThrow(localName), @@ -382,15 +221,36 @@ static class Factory ////////////////////////////////////////////////////////////////////////////////////////////////// - private final Context context; + private final PipelineOptions pipelineOptions; + private final BeamFnStateClient beamFnStateClient; + private final String pTransformId; + private final PTransform pTransform; + private final Supplier processBundleInstructionId; + private final RehydratedComponents rehydratedComponents; + private final DoFn doFn; + private final DoFnSignature doFnSignature; + private final TupleTag mainOutputTag; + private final Coder inputCoder; + private final SchemaCoder schemaCoder; + private final Coder keyCoder; + private final SchemaCoder mainOutputSchemaCoder; + private final Coder windowCoder; + private final WindowingStrategy windowingStrategy; + private final Map, SideInputSpec> tagToSideInputSpecMap; + private final Map, Coder> outputCoders; + private final ParDoPayload parDoPayload; + private final ListMultimap>> localNameToConsumer; + private final BundleSplitListener splitListener; + private final BundleFinalizer bundleFinalizer; + private final Collection>> mainOutputConsumers; private final String mainInputId; private FnApiStateAccessor stateAccessor; private final DoFnInvoker doFnInvoker; - private final DoFn.StartBundleContext startBundleContext; + private final StartBundleArgumentProvider startBundleArgumentProvider; private final ProcessBundleContext processContext; private final OnTimerContext onTimerContext; - private final 
DoFn.FinishBundleContext finishBundleContext; + private final FinishBundleArgumentProvider finishBundleArgumentProvider; /** * Only set for {@link PTransformTranslation#SPLITTABLE_PROCESS_ELEMENTS_URN} and {@link @@ -427,29 +287,162 @@ static class Factory /** Only valid during {@link #processTimer}, null otherwise. */ private TimeDomain currentTimeDomain; - FnApiDoFnRunner(Context context) { - this.context = context; + FnApiDoFnRunner( + PipelineOptions pipelineOptions, + BeamFnStateClient beamFnStateClient, + String pTransformId, + PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Map windowingStrategies, + PCollectionConsumerRegistry pCollectionConsumerRegistry, + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) { + this.pipelineOptions = pipelineOptions; + this.beamFnStateClient = beamFnStateClient; + this.pTransformId = pTransformId; + this.pTransform = pTransform; + this.processBundleInstructionId = processBundleInstructionId; + ImmutableMap.Builder, SideInputSpec> tagToSideInputSpecMapBuilder = + ImmutableMap.builder(); try { - this.mainInputId = ParDoTranslation.getMainInputName(context.pTransform); + rehydratedComponents = + RehydratedComponents.forComponents( + RunnerApi.Components.newBuilder() + .putAllCoders(coders) + .putAllPcollections(pCollections) + .putAllWindowingStrategies(windowingStrategies) + .build()) + .withPipeline(Pipeline.create()); + parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload()); + doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload); + doFnSignature = DoFnSignatures.signatureForDoFn(doFn); + switch (pTransform.getSpec().getUrn()) { + case PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN: + case PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN: + case PTransformTranslation.PAR_DO_TRANSFORM_URN: + mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload); + break; + case 
PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN: + case PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN: + case PTransformTranslation.SPLITTABLE_SPLIT_RESTRICTION_URN: + mainOutputTag = + new TupleTag(Iterables.getOnlyElement(pTransform.getOutputsMap().keySet())); + break; + default: + throw new IllegalStateException( + String.format("Unknown urn: %s", pTransform.getSpec().getUrn())); + } + String mainInputTag = + Iterables.getOnlyElement( + Sets.difference( + pTransform.getInputsMap().keySet(), + Sets.union( + parDoPayload.getSideInputsMap().keySet(), + parDoPayload.getTimerSpecsMap().keySet()))); + PCollection mainInput = pCollections.get(pTransform.getInputsOrThrow(mainInputTag)); + inputCoder = rehydratedComponents.getCoder(mainInput.getCoderId()); + if (inputCoder instanceof KvCoder + // TODO: Stop passing windowed value coders within PCollections. + || (inputCoder instanceof WindowedValue.WindowedValueCoder + && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof KvCoder))) { + this.keyCoder = + inputCoder instanceof WindowedValueCoder + ? ((KvCoder) ((WindowedValueCoder) inputCoder).getValueCoder()).getKeyCoder() + : ((KvCoder) inputCoder).getKeyCoder(); + } else { + this.keyCoder = null; + } + if (inputCoder instanceof SchemaCoder + // TODO: Stop passing windowed value coders within PCollections. + || (inputCoder instanceof WindowedValue.WindowedValueCoder + && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof SchemaCoder))) { + this.schemaCoder = + inputCoder instanceof WindowedValueCoder + ? 
(SchemaCoder) ((WindowedValueCoder) inputCoder).getValueCoder() + : ((SchemaCoder) inputCoder); + } else { + this.schemaCoder = null; + } + + windowingStrategy = + (WindowingStrategy) + rehydratedComponents.getWindowingStrategy(mainInput.getWindowingStrategyId()); + windowCoder = windowingStrategy.getWindowFn().windowCoder(); + + outputCoders = Maps.newHashMap(); + for (Map.Entry entry : pTransform.getOutputsMap().entrySet()) { + TupleTag outputTag = new TupleTag<>(entry.getKey()); + RunnerApi.PCollection outputPCollection = pCollections.get(entry.getValue()); + Coder outputCoder = rehydratedComponents.getCoder(outputPCollection.getCoderId()); + if (outputCoder instanceof WindowedValueCoder) { + outputCoder = ((WindowedValueCoder) outputCoder).getValueCoder(); + } + outputCoders.put(outputTag, outputCoder); + } + Coder outputCoder = (Coder) outputCoders.get(mainOutputTag); + mainOutputSchemaCoder = + (outputCoder instanceof SchemaCoder) ? (SchemaCoder) outputCoder : null; + + // Build the map from tag id to side input specification + for (Map.Entry entry : + parDoPayload.getSideInputsMap().entrySet()) { + String sideInputTag = entry.getKey(); + RunnerApi.SideInput sideInput = entry.getValue(); + checkArgument( + Materializations.MULTIMAP_MATERIALIZATION_URN.equals( + sideInput.getAccessPattern().getUrn()), + "This SDK is only capable of dealing with %s materializations " + + "but was asked to handle %s for PCollectionView with tag %s.", + Materializations.MULTIMAP_MATERIALIZATION_URN, + sideInput.getAccessPattern().getUrn(), + sideInputTag); + + PCollection sideInputPCollection = + pCollections.get(pTransform.getInputsOrThrow(sideInputTag)); + WindowingStrategy sideInputWindowingStrategy = + rehydratedComponents.getWindowingStrategy( + sideInputPCollection.getWindowingStrategyId()); + tagToSideInputSpecMapBuilder.put( + new TupleTag<>(entry.getKey()), + SideInputSpec.create( + rehydratedComponents.getCoder(sideInputPCollection.getCoderId()), + 
sideInputWindowingStrategy.getWindowFn().windowCoder(), + PCollectionViewTranslation.viewFnFromProto(entry.getValue().getViewFn()), + PCollectionViewTranslation.windowMappingFnFromProto( + entry.getValue().getWindowMappingFn()))); + } + } catch (IOException exn) { + throw new IllegalArgumentException("Malformed ParDoPayload", exn); + } + + ImmutableListMultimap.Builder>> + localNameToConsumerBuilder = ImmutableListMultimap.builder(); + for (Map.Entry entry : pTransform.getOutputsMap().entrySet()) { + localNameToConsumerBuilder.putAll( + entry.getKey(), pCollectionConsumerRegistry.getMultiplexingConsumer(entry.getValue())); + } + localNameToConsumer = localNameToConsumerBuilder.build(); + tagToSideInputSpecMap = tagToSideInputSpecMapBuilder.build(); + this.splitListener = splitListener; + this.bundleFinalizer = bundleFinalizer; + + try { + this.mainInputId = ParDoTranslation.getMainInputName(pTransform); } catch (IOException e) { throw new RuntimeException(e); } this.mainOutputConsumers = (Collection>>) - (Collection) context.localNameToConsumer.get(context.mainOutputTag.getId()); - this.doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.parDoPayload); - this.sideInputMapping = ParDoTranslation.getSideInputMapping(context.parDoPayload); - this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn); + (Collection) localNameToConsumer.get(mainOutputTag.getId()); + this.doFnSchemaInformation = ParDoTranslation.getSchemaInformation(parDoPayload); + this.sideInputMapping = ParDoTranslation.getSideInputMapping(parDoPayload); + this.doFnInvoker = DoFnInvokers.invokerFor(doFn); this.doFnInvoker.invokeSetup(); - this.startBundleContext = - this.context.doFn.new StartBundleContext() { - @Override - public PipelineOptions getPipelineOptions() { - return context.pipelineOptions; - } - }; - switch (context.pTransform.getSpec().getUrn()) { + this.startBundleArgumentProvider = new StartBundleArgumentProvider(); + switch (pTransform.getSpec().getUrn()) { case 
PTransformTranslation.SPLITTABLE_SPLIT_RESTRICTION_URN: // OutputT == RestrictionT this.processContext = @@ -514,36 +507,12 @@ public Instant timestamp(DoFn doFn) { break; default: throw new IllegalStateException( - String.format("Unknown URN %s", context.pTransform.getSpec().getUrn())); + String.format("Unknown URN %s", pTransform.getSpec().getUrn())); } this.onTimerContext = new OnTimerContext(); - this.finishBundleContext = - this.context.doFn.new FinishBundleContext() { - @Override - public PipelineOptions getPipelineOptions() { - return context.pipelineOptions; - } - - @Override - public void output(OutputT output, Instant timestamp, BoundedWindow window) { - outputTo( - mainOutputConsumers, - WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); - } - - @Override - public void output( - TupleTag tag, T output, Instant timestamp, BoundedWindow window) { - Collection>> consumers = - (Collection) context.localNameToConsumer.get(tag.getId()); - if (consumers == null) { - throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); - } - outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); - } - }; - - switch (context.pTransform.getSpec().getUrn()) { + this.finishBundleArgumentProvider = new FinishBundleArgumentProvider(); + + switch (pTransform.getSpec().getUrn()) { case PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN: this.convertSplitResultToWindowedSplitResult = (splitResult) -> @@ -606,7 +575,7 @@ public Object restriction() { throw new IllegalStateException( String.format( "Unimplemented split conversion handler for %s.", - context.pTransform.getSpec().getUrn())); + pTransform.getSpec().getUrn())); }; } } @@ -614,17 +583,17 @@ public Object restriction() { public void startBundle() { this.stateAccessor = new FnApiStateAccessor( - context.pipelineOptions, - context.ptransformId, - context.processBundleInstructionId, - context.tagToSideInputSpecMap, - context.beamFnStateClient, - 
context.keyCoder, - (Coder) context.windowCoder, + pipelineOptions, + pTransformId, + processBundleInstructionId, + tagToSideInputSpecMap, + beamFnStateClient, + keyCoder, + (Coder) windowCoder, () -> MoreObjects.firstNonNull(currentElement, currentTimer), () -> currentWindow); - doFnInvoker.invokeStartBundle(startBundleContext); + doFnInvoker.invokeStartBundle(startBundleArgumentProvider); } public void processElementForParDo(WindowedValue elem) { @@ -729,8 +698,7 @@ public void processElementForElementAndRestriction(WindowedValue>>> consumers = - (Collection) context.localNameToConsumer.get(timerId); + (Collection) localNameToConsumer.get(timerId); if (currentOutputTimestamp == null) { if (TimeDomain.EVENT_TIME.equals(timeDomain)) { @@ -965,6 +931,83 @@ public org.apache.beam.sdk.state.Timer get(String timerId) { } } + private class StartBundleArgumentProvider extends BaseArgumentProvider { + private class Context extends DoFn.StartBundleContext { + Context() { + doFn.super(); + } + + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } + } + + private final Context context = new Context(); + + @Override + public DoFn.StartBundleContext startBundleContext(DoFn doFn) { + return context; + } + + @Override + public BundleFinalizer bundleFinalizer() { + return bundleFinalizer; + } + + @Override + public String getErrorContext() { + return "FnApiDoFnRunner/StartBundle"; + } + } + + private class FinishBundleArgumentProvider extends BaseArgumentProvider { + private class Context extends DoFn.FinishBundleContext { + Context() { + doFn.super(); + } + + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } + + @Override + public void output(OutputT output, Instant timestamp, BoundedWindow window) { + outputTo( + mainOutputConsumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); + } + + @Override + public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { + 
Collection>> consumers = + (Collection) localNameToConsumer.get(tag.getId()); + if (consumers == null) { + throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); + } + outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); + } + } + + private final Context context = new Context(); + + @Override + public DoFn.FinishBundleContext finishBundleContext( + DoFn doFn) { + return context; + } + + @Override + public BundleFinalizer bundleFinalizer() { + return bundleFinalizer; + } + + @Override + public String getErrorContext() { + return "FnApiDoFnRunner/FinishBundle"; + } + } + /** * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.ProcessElement @ProcessElement}. */ @@ -972,7 +1015,7 @@ private class ProcessBundleContext extends DoFn.ProcessContext implements DoFnInvoker.ArgumentProvider { private ProcessBundleContext() { - context.doFn.super(); + doFn.super(); } @Override @@ -1043,12 +1086,17 @@ public OutputReceiver outputReceiver(DoFn doFn) { @Override public OutputReceiver outputRowReceiver(DoFn doFn) { - return DoFnOutputReceivers.rowReceiver(this, null, context.mainOutputSchemaCoder); + return DoFnOutputReceivers.rowReceiver(this, null, mainOutputSchemaCoder); } @Override public MultiOutputReceiver taggedOutputReceiver(DoFn doFn) { - return DoFnOutputReceivers.windowedMultiReceiver(this, context.outputCoders); + return DoFnOutputReceivers.windowedMultiReceiver(this, outputCoders); + } + + @Override + public BundleFinalizer bundleFinalizer() { + return bundleFinalizer; } @Override @@ -1069,11 +1117,11 @@ public DoFn.OnTimerContext onTimerContext(DoFn @Override public State state(String stateId, boolean alwaysFetched) { - StateDeclaration stateDeclaration = context.doFnSignature.stateDeclarations().get(stateId); + StateDeclaration stateDeclaration = doFnSignature.stateDeclarations().get(stateId); checkNotNull(stateDeclaration, "No state declaration found for %s", stateId); StateSpec spec; try { - 
spec = (StateSpec) stateDeclaration.field().get(context.doFn); + spec = (StateSpec) stateDeclaration.field().get(doFn); } catch (IllegalAccessException e) { throw new RuntimeException(e); } @@ -1103,12 +1151,12 @@ public TimerMap timerFamily(String tagId) { @Override public PipelineOptions getPipelineOptions() { - return context.pipelineOptions; + return pipelineOptions; } @Override public PipelineOptions pipelineOptions() { - return context.pipelineOptions; + return pipelineOptions; } @Override @@ -1131,7 +1179,7 @@ public void output(TupleTag tag, T output) { @Override public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { Collection>> consumers = - (Collection) context.localNameToConsumer.get(tag.getId()); + (Collection) localNameToConsumer.get(tag.getId()); if (consumers == null) { throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); } @@ -1166,110 +1214,130 @@ public void updateWatermark(Instant watermark) { } /** Provides arguments for a {@link DoFnInvoker} for {@link DoFn.OnTimer @OnTimer}. 
*/ - private class OnTimerContext extends DoFn.OnTimerContext - implements DoFnInvoker.ArgumentProvider { + private class OnTimerContext extends BaseArgumentProvider { + private class Context extends DoFn.OnTimerContext { + private Context() { + doFn.super(); + } - private OnTimerContext() { - context.doFn.super(); - } + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } - @Override - public BoundedWindow window() { - return currentWindow; - } + @Override + public BoundedWindow window() { + return currentWindow; + } - @Override - public PaneInfo paneInfo(DoFn doFn) { - throw new UnsupportedOperationException( - "Cannot access paneInfo outside of @ProcessElement methods."); - } + @Override + public void output(OutputT output) { + outputTo( + mainOutputConsumers, + WindowedValue.of( + output, currentTimer.getTimestamp(), currentWindow, PaneInfo.NO_FIRING)); + } - @Override - public DoFn.StartBundleContext startBundleContext(DoFn doFn) { - throw new UnsupportedOperationException( - "Cannot access StartBundleContext outside of @StartBundle method."); - } + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + checkArgument( + !currentTimer.getTimestamp().isAfter(timestamp), + "Output time %s can not be before timer timestamp %s.", + timestamp, + currentTimer.getTimestamp()); + outputTo( + mainOutputConsumers, + WindowedValue.of(output, timestamp, currentWindow, PaneInfo.NO_FIRING)); + } - @Override - public DoFn.FinishBundleContext finishBundleContext( - DoFn doFn) { - throw new UnsupportedOperationException( - "Cannot access FinishBundleContext outside of @FinishBundle method."); - } + @Override + public void output(TupleTag tag, T output) { + Collection>> consumers = + (Collection) localNameToConsumer.get(tag.getId()); + if (consumers == null) { + throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); + } + outputTo( + consumers, + WindowedValue.of( + output, 
currentTimer.getTimestamp(), currentWindow, PaneInfo.NO_FIRING)); + } - @Override - public DoFn.ProcessContext processContext(DoFn doFn) { - throw new UnsupportedOperationException( - "Cannot access ProcessContext outside of @ProcessElement method."); - } + @Override + public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + checkArgument( + !currentTimer.getTimestamp().isAfter(timestamp), + "Output time %s can not be before timer timestamp %s.", + timestamp, + currentTimer.getTimestamp()); + Collection>> consumers = + (Collection) localNameToConsumer.get(tag.getId()); + if (consumers == null) { + throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); + } + outputTo(consumers, WindowedValue.of(output, timestamp, currentWindow, PaneInfo.NO_FIRING)); + } - @Override - public InputT element(DoFn doFn) { - throw new UnsupportedOperationException("Element parameters are not supported."); - } + @Override + public TimeDomain timeDomain() { + return currentTimeDomain; + } - @Override - public InputT sideInput(String tagId) { - throw new UnsupportedOperationException("SideInput parameters are not supported."); - } + @Override + public Instant fireTimestamp() { + return currentTimer.getValue().getValue().getTimestamp(); + } - @Override - public Object schemaElement(int index) { - throw new UnsupportedOperationException("Element parameters are not supported."); + @Override + public Instant timestamp() { + return currentTimer.getTimestamp(); + } } + private final Context context = new Context(); + @Override - public Instant timestamp(DoFn doFn) { - return timestamp(); + public BoundedWindow window() { + return currentWindow; } @Override - public String timerId(DoFn doFn) { - throw new UnsupportedOperationException("TimerId parameters are not supported."); + public Instant timestamp(DoFn doFn) { + return currentTimer.getTimestamp(); } @Override public TimeDomain timeDomain(DoFn doFn) { - return timeDomain(); + return 
currentTimeDomain; } @Override public OutputReceiver outputReceiver(DoFn doFn) { - return DoFnOutputReceivers.windowedReceiver(this, null); + return DoFnOutputReceivers.windowedReceiver(context, null); } @Override public OutputReceiver outputRowReceiver(DoFn doFn) { - return DoFnOutputReceivers.rowReceiver(this, null, context.mainOutputSchemaCoder); + return DoFnOutputReceivers.rowReceiver(context, null, mainOutputSchemaCoder); } @Override public MultiOutputReceiver taggedOutputReceiver(DoFn doFn) { - return DoFnOutputReceivers.windowedMultiReceiver(this); - } - - @Override - public Object restriction() { - throw new UnsupportedOperationException("Restriction parameters are not supported."); + return DoFnOutputReceivers.windowedMultiReceiver(context); } @Override public DoFn.OnTimerContext onTimerContext(DoFn doFn) { - return this; - } - - @Override - public RestrictionTracker restrictionTracker() { - throw new UnsupportedOperationException("RestrictionTracker parameters are not supported."); + return context; } @Override public State state(String stateId, boolean alwaysFetched) { - StateDeclaration stateDeclaration = context.doFnSignature.stateDeclarations().get(stateId); + StateDeclaration stateDeclaration = doFnSignature.stateDeclarations().get(stateId); checkNotNull(stateDeclaration, "No state declaration found for %s", stateId); StateSpec spec; try { - spec = (StateSpec) stateDeclaration.field().get(context.doFn); + spec = (StateSpec) stateDeclaration.field().get(doFn); } catch (IllegalAccessException e) { throw new RuntimeException(e); } @@ -1294,78 +1362,17 @@ public org.apache.beam.sdk.state.Timer timer(String timerId) { @Override public TimerMap timerFamily(String tagId) { // TODO: implement timerFamily - throw new UnsupportedOperationException("TimerFamily parameters are not supported."); - } - - @Override - public PipelineOptions getPipelineOptions() { - return context.pipelineOptions; + return super.timerFamily(tagId); } @Override public PipelineOptions 
pipelineOptions() { - return context.pipelineOptions; - } - - @Override - public void output(OutputT output) { - outputTo( - mainOutputConsumers, - WindowedValue.of(output, currentTimer.getTimestamp(), currentWindow, PaneInfo.NO_FIRING)); - } - - @Override - public void outputWithTimestamp(OutputT output, Instant timestamp) { - checkArgument( - !currentTimer.getTimestamp().isAfter(timestamp), - "Output time %s can not be before timer timestamp %s.", - timestamp, - currentTimer.getTimestamp()); - outputTo( - mainOutputConsumers, - WindowedValue.of(output, timestamp, currentWindow, PaneInfo.NO_FIRING)); - } - - @Override - public void output(TupleTag tag, T output) { - Collection>> consumers = - (Collection) context.localNameToConsumer.get(tag.getId()); - if (consumers == null) { - throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); - } - outputTo( - consumers, - WindowedValue.of(output, currentTimer.getTimestamp(), currentWindow, PaneInfo.NO_FIRING)); - } - - @Override - public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { - checkArgument( - !currentTimer.getTimestamp().isAfter(timestamp), - "Output time %s can not be before timer timestamp %s.", - timestamp, - currentTimer.getTimestamp()); - Collection>> consumers = - (Collection) context.localNameToConsumer.get(tag.getId()); - if (consumers == null) { - throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); - } - outputTo(consumers, WindowedValue.of(output, timestamp, currentWindow, PaneInfo.NO_FIRING)); + return pipelineOptions; } @Override - public TimeDomain timeDomain() { - return currentTimeDomain; - } - - @Override - public Instant fireTimestamp() { - return currentTimer.getValue().getValue().getTimestamp(); - } - - @Override - public Instant timestamp() { - return currentTimer.getTimestamp(); + public String getErrorContext() { + return "FnApiDoFnRunner/OnTimer"; } } } diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java index 579d44707de16..ac79161edda70 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java @@ -23,6 +23,7 @@ import java.util.function.Function; import org.apache.beam.fn.harness.control.AddHarnessIdInterceptor; import org.apache.beam.fn.harness.control.BeamFnControlClient; +import org.apache.beam.fn.harness.control.FinalizeBundleHandler; import org.apache.beam.fn.harness.control.ProcessBundleHandler; import org.apache.beam.fn.harness.control.RegisterHandler; import org.apache.beam.fn.harness.data.BeamFnDataGrpcClient; @@ -190,10 +191,20 @@ public static void main( new BeamFnStateGrpcClientCache( idGenerator, channelFactory::forDescriptor, outboundObserverFactory); + FinalizeBundleHandler finalizeBundleHandler = + new FinalizeBundleHandler(options.as(GcsOptions.class).getExecutorService()); + ProcessBundleHandler processBundleHandler = new ProcessBundleHandler( - options, fnApiRegistry::getById, beamFnDataMultiplexer, beamFnStateGrpcClientCache); + options, + fnApiRegistry::getById, + beamFnDataMultiplexer, + beamFnStateGrpcClientCache, + finalizeBundleHandler); handlers.put(BeamFnApi.InstructionRequest.RequestCase.REGISTER, fnApiRegistry::register); + handlers.put( + BeamFnApi.InstructionRequest.RequestCase.FINALIZE_BUNDLE, + finalizeBundleHandler::finalizeBundle); // TODO(BEAM-6597): Collect MonitoringInfos in ProcessBundleProgressResponses. 
handlers.put( BeamFnApi.InstructionRequest.RequestCase.PROCESS_BUNDLE, diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java index 0cd5dcae08542..d680b471f46a9 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java @@ -35,6 +35,7 @@ import org.apache.beam.sdk.function.ThrowingFunction; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; @@ -110,7 +111,8 @@ public Mapper createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { FnDataReceiver> consumer = diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java index 389331f87c218..c21e9e2d22579 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java @@ -32,6 +32,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; /** A factory able to instantiate an appropriate handler for a given PTransform. 
*/ public interface PTransformRunnerFactory { @@ -59,6 +60,9 @@ public interface PTransformRunnerFactory { * @param finishFunctionRegistry A class to register a finish bundle handler with. * @param addTearDownFunction A consumer to register a tear down handler with. * @param splitListener A listener to be invoked when the PTransform splits itself. + * @param bundleFinalizer Register callbacks that will be invoked when the runner completes the + * bundle. The specified instant provides the timeout on how long the finalization callback is + * valid for. */ T createRunnerForPTransform( PipelineOptions pipelineOptions, @@ -74,7 +78,8 @@ T createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer addTearDownFunction, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException; /** diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java new file mode 100644 index 0000000000000..985a9cd00baca --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.control; + +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; + +import com.google.auto.value.AutoValue; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.PriorityQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleResponse; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A bundle finalization handler that expires entries after a specified amount of time. + * + *

Callers should register new callbacks via {@link #registerCallbacks} and fire existing + * callbacks using {@link #finalizeBundle}. + * + *

See Apache Beam Portability API: How to + * Finalize Bundles for further details. + */ +public class FinalizeBundleHandler { + + /** A {@link BundleFinalizer.Callback} and expiry time pair. */ + @AutoValue + abstract static class CallbackRegistration { + public static CallbackRegistration create( + Instant expiryTime, BundleFinalizer.Callback callback) { + return new AutoValue_FinalizeBundleHandler_CallbackRegistration(expiryTime, callback); + } + + public abstract Instant getExpiryTime(); + + public abstract BundleFinalizer.Callback getCallback(); + } + + private static final Logger LOGGER = LoggerFactory.getLogger(FinalizeBundleHandler.class); + private final ConcurrentMap> bundleFinalizationCallbacks; + private final PriorityQueue> cleanUpQueue; + private final Future cleanUpResult; + + public FinalizeBundleHandler(ExecutorService executorService) { + this.bundleFinalizationCallbacks = new ConcurrentHashMap<>(); + this.cleanUpQueue = + new PriorityQueue<>(11, Comparator.comparing(TimestampedValue::getTimestamp)); + this.cleanUpResult = + executorService.submit( + (Callable) + () -> { + while (true) { + synchronized (cleanUpQueue) { + TimestampedValue expiryTime = cleanUpQueue.peek(); + + // Wait until we have at least one element. We are notified on each element + // being added. + while (expiryTime == null) { + cleanUpQueue.wait(); + expiryTime = cleanUpQueue.peek(); + } + + // Wait until the current time has passed the expiry time for the head of the + // queue. + // We are notified on each element being added. 
+ Instant now = Instant.now(); + while (expiryTime.getTimestamp().isAfter(now)) { + Duration timeDifference = new Duration(now, expiryTime.getTimestamp()); + cleanUpQueue.wait(timeDifference.getMillis()); + expiryTime = cleanUpQueue.peek(); + now = Instant.now(); + } + + bundleFinalizationCallbacks.remove(cleanUpQueue.poll().getValue()); + } + } + }); + } + + public void registerCallbacks(String bundleId, Collection callbacks) { + if (callbacks.isEmpty()) { + return; + } + + Collection priorCallbacks = + bundleFinalizationCallbacks.putIfAbsent(bundleId, callbacks); + checkState( + priorCallbacks == null, + "Expected to not have any past callbacks for bundle %s but found %s.", + bundleId, + priorCallbacks); + long expiryTimeMillis = Long.MIN_VALUE; + for (CallbackRegistration callback : callbacks) { + expiryTimeMillis = Math.max(expiryTimeMillis, callback.getExpiryTime().getMillis()); + } + synchronized (cleanUpQueue) { + cleanUpQueue.offer(TimestampedValue.of(bundleId, new Instant(expiryTimeMillis))); + cleanUpQueue.notify(); + } + } + + public BeamFnApi.InstructionResponse.Builder finalizeBundle(BeamFnApi.InstructionRequest request) + throws Exception { + String bundleId = request.getFinalizeBundle().getInstructionId(); + + Collection callbacks = bundleFinalizationCallbacks.remove(bundleId); + + if (callbacks == null) { + // We have already processed the callbacks on a prior bundle finalization attempt + return BeamFnApi.InstructionResponse.newBuilder() + .setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance()); + } + + Collection failures = new ArrayList<>(); + for (CallbackRegistration callback : callbacks) { + try { + callback.getCallback().onBundleSuccess(); + } catch (Exception e) { + failures.add(e); + } + } + if (!failures.isEmpty()) { + Exception e = + new Exception( + String.format("Failed to handle bundle finalization for bundle %s.", bundleId)); + for (Exception failure : failures) { + e.addSuppressed(failure); + } + throw e; + } + + return 
BeamFnApi.InstructionResponse.newBuilder() + .setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance()); + } +} diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index 00882b5607812..48e24509bf96d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -21,6 +21,7 @@ import java.io.Closeable; import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.List; @@ -36,6 +37,7 @@ import java.util.function.Supplier; import org.apache.beam.fn.harness.PTransformRunnerFactory; import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar; +import org.apache.beam.fn.harness.control.FinalizeBundleHandler.CallbackRegistration; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry; import org.apache.beam.fn.harness.data.PTransformFunctionRegistry; @@ -63,18 +65,21 @@ import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Message; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.TextFormat; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashMultimap; +import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.SetMultimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; +import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -121,6 +126,7 @@ public class ProcessBundleHandler { private final Function fnApiRegistry; private final BeamFnDataClient beamFnDataClient; private final BeamFnStateGrpcClientCache beamFnStateGrpcClientCache; + private final FinalizeBundleHandler finalizeBundleHandler; private final Map urnToPTransformRunnerFactoryMap; private final PTransformRunnerFactory defaultPTransformRunnerFactory; @VisibleForTesting final BundleProcessorCache bundleProcessorCache; @@ -129,12 +135,14 @@ public ProcessBundleHandler( PipelineOptions options, Function fnApiRegistry, BeamFnDataClient beamFnDataClient, - BeamFnStateGrpcClientCache beamFnStateGrpcClientCache) { + BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, + FinalizeBundleHandler finalizeBundleHandler) { this( options, fnApiRegistry, beamFnDataClient, beamFnStateGrpcClientCache, + finalizeBundleHandler, REGISTERED_RUNNER_FACTORIES, new BundleProcessorCache()); } @@ -145,12 +153,14 @@ public ProcessBundleHandler( Function fnApiRegistry, BeamFnDataClient beamFnDataClient, BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, + FinalizeBundleHandler finalizeBundleHandler, Map urnToPTransformRunnerFactoryMap, BundleProcessorCache bundleProcessorCache) { this.options = options; this.fnApiRegistry = fnApiRegistry; this.beamFnDataClient = beamFnDataClient; 
this.beamFnStateGrpcClientCache = beamFnStateGrpcClientCache; + this.finalizeBundleHandler = finalizeBundleHandler; this.urnToPTransformRunnerFactoryMap = urnToPTransformRunnerFactoryMap; this.defaultPTransformRunnerFactory = new UnknownPTransformRunnerFactory(urnToPTransformRunnerFactoryMap.keySet()); @@ -170,7 +180,8 @@ private void createRunnerAndConsumersForPTransformRecursively( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer addTearDownFunction, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { // Recursively ensure that all consumers of the output PCollection have been created. @@ -192,7 +203,8 @@ private void createRunnerAndConsumersForPTransformRecursively( startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener); + splitListener, + bundleFinalizer); } } @@ -225,7 +237,8 @@ private void createRunnerAndConsumersForPTransformRecursively( startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener); + splitListener, + bundleFinalizer); processedPTransformIds.add(pTransformId); } } @@ -298,6 +311,14 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction for (MonitoringInfo mi : metricsContainerRegistry.getMonitoringInfos()) { response.addMonitoringInfos(mi); } + + if (!bundleProcessor.getBundleFinalizationCallbackRegistrations().isEmpty()) { + finalizeBundleHandler.registerCallbacks( + bundleProcessor.getInstructionId(), + ImmutableList.copyOf(bundleProcessor.getBundleFinalizationCallbackRegistrations())); + response.setRequiresFinalization(true); + } + bundleProcessorCache.release( request.getProcessBundle().getProcessBundleDescriptorId(), bundleProcessor); } @@ -382,6 +403,16 @@ private BundleProcessor createBundleProcessor( } }; + Collection bundleFinalizationCallbackRegistrations = new ArrayList<>(); + BundleFinalizer bundleFinalizer = + 
new BundleFinalizer() { + @Override + public void afterBundleCommit(Instant callbackExpiry, Callback callback) { + bundleFinalizationCallbackRegistrations.add( + CallbackRegistration.create(callbackExpiry, callback)); + } + }; + BundleProcessor bundleProcessor = BundleProcessor.create( startFunctionRegistry, @@ -392,7 +423,8 @@ private BundleProcessor createBundleProcessor( metricsContainerRegistry, stateTracker, beamFnStateClient, - queueingClient); + queueingClient, + bundleFinalizationCallbackRegistrations); // Create a BeamFnStateClient for (Map.Entry entry : @@ -420,7 +452,8 @@ private BundleProcessor createBundleProcessor( startFunctionRegistry, finishFunctionRegistry, tearDownFunctions::add, - splitListener); + splitListener, + bundleFinalizer); } return bundleProcessor; } @@ -517,7 +550,8 @@ public static BundleProcessor create( MetricsContainerStepMap metricsContainerRegistry, ExecutionStateTracker stateTracker, HandleStateCallsForBundle beamFnStateClient, - QueueingBeamFnDataClient queueingClient) { + QueueingBeamFnDataClient queueingClient, + Collection bundleFinalizationCallbackRegistrations) { return new AutoValue_ProcessBundleHandler_BundleProcessor( startFunctionRegistry, finishFunctionRegistry, @@ -527,7 +561,8 @@ public static BundleProcessor create( metricsContainerRegistry, stateTracker, beamFnStateClient, - queueingClient); + queueingClient, + bundleFinalizationCallbackRegistrations); } private String instructionId; @@ -550,6 +585,8 @@ public static BundleProcessor create( abstract QueueingBeamFnDataClient getQueueingClient(); + abstract Collection getBundleFinalizationCallbackRegistrations(); + String getInstructionId() { return this.instructionId; } @@ -566,6 +603,7 @@ void reset() { getMetricsContainerRegistry().reset(); getStateTracker().reset(); ExecutionStateSampler.instance().reset(); + getBundleFinalizationCallbackRegistrations().clear(); } } @@ -659,7 +697,8 @@ public Object createRunnerForPTransform( PTransformFunctionRegistry 
startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer tearDownFunctions, - BundleSplitListener splitListener) { + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) { String message = String.format( "No factory registered for %s, known factories %s", diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java index c27400f5f05fd..f00c4f941e8ae 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java @@ -206,7 +206,8 @@ public Coder windowCoder() { null /* startFunctionRegistry */, null, /* finishFunctionRegistry */ null, /* tearDownRegistry */ - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); WindowedValue value = WindowedValue.of( diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java index dc319d7312628..8b9daab05656d 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java @@ -175,7 +175,8 @@ public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); assertThat(teardownFunctions, empty()); @@ -476,6 +477,7 @@ private BeamFnDataReadRunner createReadRunner( startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); } } diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java index dade2d9094899..5085fe9320480 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java @@ -149,7 +149,8 @@ public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); assertThat(teardownFunctions, empty()); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java index 0f461d20c4650..e419783da7ae3 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java @@ -173,7 +173,8 @@ public void testCreatingAndProcessingSourceFromFactory() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); // This is testing a deprecated way of running sources and should be removed // once all source definitions are instead propagated along the input edge. 
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java index 7f5be398a148d..0b0002f72b869 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java @@ -157,6 +157,7 @@ public void testPrecombine() throws Exception { startFunctionRegistry, finishFunctionRegistry, null, + null, null); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); @@ -232,6 +233,7 @@ public void testMergeAccumulators() throws Exception { startFunctionRegistry, finishFunctionRegistry, null, + null, null); assertThat(startFunctionRegistry.getFunctions(), empty()); @@ -295,6 +297,7 @@ public void testExtractOutputs() throws Exception { startFunctionRegistry, finishFunctionRegistry, null, + null, null); assertThat(startFunctionRegistry.getFunctions(), empty()); @@ -358,6 +361,7 @@ public void testCombineGroupedValues() throws Exception { startFunctionRegistry, finishFunctionRegistry, null, + null, null); assertThat(startFunctionRegistry.getFunctions(), empty()); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java index dc78ea2b22c84..d9f55f76015cc 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java @@ -92,7 +92,8 @@ public void testCreatingAndProcessingDoFlatten() throws Exception { null /* startFunctionRegistry */, null, /* finishFunctionRegistry */ null, /* tearDownRegistry */ - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); mainOutputValues.clear(); assertThat( @@ -160,7 +161,8 @@ public void 
testFlattenWithDuplicateInputCollectionProducesMultipleOutputs() thr null /* startFunctionRegistry */, null, /* finishFunctionRegistry */ null, /* tearDownRegistry */ - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); mainOutputValues.clear(); assertThat(consumers.keySet(), containsInAnyOrder("inputATarget", "mainOutputTarget")); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java index cf75bee2302e1..d5e8b6df6b0e8 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java @@ -239,7 +239,8 @@ public void testUsingUserState() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -414,7 +415,8 @@ public void testBasicWithSideInputsAndOutputs() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -552,7 +554,8 @@ public void testSideInputIsAccessibleForDownstreamCallers() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -667,7 +670,8 @@ public void testUsingMetrics() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + 
null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -878,7 +882,8 @@ public void testTimers() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -1153,7 +1158,8 @@ public void split( primarySplits.addAll(primaryRoots); residualSplits.addAll(residualRoots); } - }); + }, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -1259,7 +1265,8 @@ public void testProcessElementForPairWithRestriction() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* bundleSplitListener */); + null /* bundleSplitListener */, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); @@ -1348,7 +1355,8 @@ public void testProcessElementForSplitAndSizeRestriction() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* bundleSplitListener */); + null /* bundleSplitListener */, + null /* bundleFinalizer */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); mainOutputValues.clear(); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java index 956e295dd31e7..7d9abbbde6885 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java @@ -93,7 +93,8 @@ public void testValueOnlyMapping() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + 
null /* splitListener */, + null /* bundleFinalizer */); assertThat(startFunctionRegistry.getFunctions(), empty()); assertThat(finishFunctionRegistry.getFunctions(), empty()); @@ -138,7 +139,8 @@ public void testFullWindowedValueMapping() throws Exception { startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); assertThat(startFunctionRegistry.getFunctions(), empty()); assertThat(finishFunctionRegistry.getFunctions(), empty()); @@ -182,7 +184,8 @@ public void testFullWindowedValueMappingWithCompressedWindow() throws Exception startFunctionRegistry, finishFunctionRegistry, teardownFunctions::add, - null /* splitListener */); + null /* splitListener */, + null /* bundleFinalizer */); assertThat(startFunctionRegistry.getFunctions(), empty()); assertThat(finishFunctionRegistry.getFunctions(), empty()); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java new file mode 100644 index 0000000000000..a760d22b78af0 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.control; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.StringContains.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.beam.fn.harness.control.FinalizeBundleHandler.CallbackRegistration; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleRequest; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleResponse; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionRequest; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link FinalizeBundleHandler}. 
*/ +@RunWith(JUnit4.class) +public class FinalizeBundleHandlerTest { + private static final String INSTRUCTION_ID = "instructionId"; + private static final InstructionResponse SUCCESSFUL_RESPONSE = + InstructionResponse.newBuilder() + .setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance()) + .build(); + + @Test + public void testRegistrationAndCallback() throws Exception { + AtomicBoolean wasCalled1 = new AtomicBoolean(); + AtomicBoolean wasCalled2 = new AtomicBoolean(); + List callbacks = new ArrayList<>(); + callbacks.add( + CallbackRegistration.create( + Instant.now().plus(Duration.standardHours(1)), () -> wasCalled1.set(true))); + callbacks.add( + CallbackRegistration.create( + Instant.now().plus(Duration.standardHours(1)), () -> wasCalled2.set(true))); + + FinalizeBundleHandler handler = new FinalizeBundleHandler(Executors.newCachedThreadPool()); + handler.registerCallbacks("test", callbacks); + assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test")).build()); + assertTrue(wasCalled1.get()); + assertTrue(wasCalled2.get()); + } + + @Test + public void testFinalizationIgnoresMissingBundleIds() throws Exception { + FinalizeBundleHandler handler = new FinalizeBundleHandler(Executors.newCachedThreadPool()); + assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test")).build()); + } + + @Test + public void testFinalizationContinuesToNextCallbackEvenInFailure() throws Exception { + List callbacks = new ArrayList<>(); + AtomicBoolean wasCalled1 = new AtomicBoolean(); + AtomicBoolean wasCalled2 = new AtomicBoolean(); + callbacks.add( + CallbackRegistration.create( + Instant.now().plus(Duration.standardHours(1)), + () -> { + wasCalled1.set(true); + throw new Exception("testException1"); + })); + callbacks.add( + CallbackRegistration.create( + Instant.now().plus(Duration.standardHours(1)), + () -> { + wasCalled2.set(true); + throw new Exception("testException2"); + })); + + FinalizeBundleHandler handler = new 
FinalizeBundleHandler(Executors.newCachedThreadPool()); + handler.registerCallbacks("test", callbacks); + + try { + handler.finalizeBundle(requestFor("test")); + fail(); + } catch (Exception e) { + assertThat(e.getMessage(), containsString("Failed to handle bundle finalization for bundle")); + assertEquals(2, e.getSuppressed().length); + assertTrue(wasCalled1.get()); + assertTrue(wasCalled2.get()); + } + } + + private static InstructionRequest requestFor(String bundleId) { + return InstructionRequest.newBuilder() + .setInstructionId(INSTRUCTION_ID) + .setFinalizeBundle(FinalizeBundleRequest.newBuilder().setInstructionId(bundleId).build()) + .build(); + } +} diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index f6e65537b6e11..f9aeb6b485f3e 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -21,12 +21,15 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; +import static org.mockito.Mockito.argThat; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -34,6 +37,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import 
java.util.List; import java.util.Map; @@ -42,6 +46,7 @@ import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.PTransformRunnerFactory; +import org.apache.beam.fn.harness.control.FinalizeBundleHandler.CallbackRegistration; import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor; import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessorCache; import org.apache.beam.fn.harness.data.BeamFnDataClient; @@ -71,6 +76,7 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.DoFnWithExecutionInformation; @@ -80,9 +86,11 @@ import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Message; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles; +import org.joda.time.Instant; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -213,6 +221,11 @@ QueueingBeamFnDataClient getQueueingClient() { return wrappedBundleProcessor.getQueueingClient(); } + @Override + Collection getBundleFinalizationCallbackRegistrations() { + return wrappedBundleProcessor.getBundleFinalizationCallbackRegistrations(); + } + @Override void reset() { resetCnt++; @@ -269,7 +282,8 @@ public void testOrderOfStartAndFinishCalls() throws 
Exception { startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener) -> { + splitListener, + bundleFinalizer) -> { transformsProcessed.add(pTransform); startFunctionRegistry.register( pTransformId, @@ -292,6 +306,7 @@ public void testOrderOfStartAndFinishCalls() throws Exception { fnApiRegistry::get, beamFnDataClient, null /* beamFnStateClient */, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, startFinishRecorder, DATA_OUTPUT_URN, startFinishRecorder), @@ -398,7 +413,8 @@ public void testOrderOfSetupTeardownCalls() throws Exception { startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener) -> null); + splitListener, + bundleFinalizer) -> null); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -406,6 +422,7 @@ public void testOrderOfSetupTeardownCalls() throws Exception { fnApiRegistry::get, beamFnDataClient, null /* beamFnStateClient */, + null /* finalizeBundleHandler */, urnToPTransformRunnerFactoryMap, new BundleProcessorCache()); @@ -451,6 +468,7 @@ public void testBundleProcessorIsResetWhenAddedBackToCache() throws Exception { fnApiRegistry::get, beamFnDataClient, null /* beamFnStateGrpcClientCache */, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, (pipelineOptions, @@ -466,7 +484,8 @@ public void testBundleProcessorIsResetWhenAddedBackToCache() throws Exception { startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener) -> null), + splitListener, + bundleFinalizer) -> null), new TestBundleProcessorCache()); assertThat(TestBundleProcessor.resetCnt, equalTo(0)); @@ -510,6 +529,7 @@ public void testBundleProcessorReset() { PTransformFunctionRegistry startFunctionRegistry = mock(PTransformFunctionRegistry.class); PTransformFunctionRegistry finishFunctionRegistry = mock(PTransformFunctionRegistry.class); Multimap allResiduals = mock(Multimap.class); + Collection bundleFinalizationCallbacks = mock(Collection.class); 
PCollectionConsumerRegistry pCollectionConsumerRegistry = mock(PCollectionConsumerRegistry.class); MetricsContainerStepMap metricsContainerRegistry = mock(MetricsContainerStepMap.class); @@ -527,7 +547,8 @@ public void testBundleProcessorReset() { metricsContainerRegistry, stateTracker, beamFnStateClient, - queueingClient); + queueingClient, + bundleFinalizationCallbacks); bundleProcessor.reset(); verify(startFunctionRegistry, times(1)).reset(); @@ -536,6 +557,7 @@ public void testBundleProcessorReset() { verify(pCollectionConsumerRegistry, times(1)).reset(); verify(metricsContainerRegistry, times(1)).reset(); verify(stateTracker, times(1)).reset(); + verify(bundleFinalizationCallbacks, times(1)).clear(); } @Test @@ -556,6 +578,7 @@ public void testCreatingPTransformExceptionsArePropagated() throws Exception { fnApiRegistry::get, beamFnDataClient, null /* beamFnStateGrpcClientCache */, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, (pipelineOptions, @@ -571,7 +594,8 @@ public void testCreatingPTransformExceptionsArePropagated() throws Exception { startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener) -> { + splitListener, + bundleFinalizer) -> { thrown.expect(IllegalStateException.class); thrown.expectMessage("TestException"); throw new IllegalStateException("TestException"); @@ -584,6 +608,74 @@ public void testCreatingPTransformExceptionsArePropagated() throws Exception { .build()); } + @Test + public void testBundleFinalizationIsPropagated() throws Exception { + BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = + BeamFnApi.ProcessBundleDescriptor.newBuilder() + .putTransforms( + "2L", + RunnerApi.PTransform.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()) + .build()) + .build(); + Map fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor); + FinalizeBundleHandler mockFinalizeBundleHandler = mock(FinalizeBundleHandler.class); + BundleFinalizer.Callback 
mockCallback = mock(BundleFinalizer.Callback.class); + + ProcessBundleHandler handler = + new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient, + null /* beamFnStateGrpcClientCache */, + mockFinalizeBundleHandler, + ImmutableMap.of( + DATA_INPUT_URN, + (PTransformRunnerFactory) + (pipelineOptions, + beamFnDataClient, + beamFnStateClient, + pTransformId, + pTransform, + processBundleInstructionId, + pCollections, + coders, + windowingStrategies, + pCollectionConsumerRegistry, + startFunctionRegistry, + finishFunctionRegistry, + addTearDownFunction, + splitListener, + bundleFinalizer) -> { + startFunctionRegistry.register( + pTransformId, + () -> + bundleFinalizer.afterBundleCommit( + Instant.ofEpochMilli(42L), mockCallback)); + return null; + }), + new BundleProcessorCache()); + BeamFnApi.InstructionResponse.Builder response = + handler.processBundle( + BeamFnApi.InstructionRequest.newBuilder() + .setInstructionId("2L") + .setProcessBundle( + BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")) + .build()); + + assertTrue(response.getProcessBundle().getRequiresFinalization()); + verify(mockFinalizeBundleHandler) + .registerCallbacks( + eq("2L"), + argThat( + (Collection arg) -> { + CallbackRegistration registration = Iterables.getOnlyElement(arg); + assertEquals(Instant.ofEpochMilli(42L), registration.getExpiryTime()); + assertSame(mockCallback, registration.getCallback()); + return true; + })); + } + @Test public void testPTransformStartExceptionsArePropagated() throws Exception { BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = @@ -602,6 +694,7 @@ public void testPTransformStartExceptionsArePropagated() throws Exception { fnApiRegistry::get, beamFnDataClient, null /* beamFnStateGrpcClientCache */, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, (PTransformRunnerFactory) @@ -618,7 +711,8 @@ public void testPTransformStartExceptionsArePropagated() throws 
Exception { startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener) -> { + splitListener, + bundleFinalizer) -> { thrown.expect(IllegalStateException.class); thrown.expectMessage("TestException"); startFunctionRegistry.register( @@ -656,6 +750,7 @@ public void testPTransformFinishExceptionsArePropagated() throws Exception { fnApiRegistry::get, beamFnDataClient, null /* beamFnStateGrpcClientCache */, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, (PTransformRunnerFactory) @@ -672,7 +767,8 @@ public void testPTransformFinishExceptionsArePropagated() throws Exception { startFunctionRegistry, finishFunctionRegistry, addTearDownFunction, - splitListener) -> { + splitListener, + bundleFinalizer) -> { thrown.expect(IllegalStateException.class); thrown.expectMessage("TestException"); finishFunctionRegistry.register( @@ -746,6 +842,7 @@ public void testPendingStateCallsBlockTillCompletion() throws Exception { fnApiRegistry::get, beamFnDataClient, mockBeamFnStateGrpcClient, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, new PTransformRunnerFactory() { @@ -764,7 +861,8 @@ public Object createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer addTearDownFunction, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { startFunctionRegistry.register( pTransformId, () -> doStateCalls(beamFnStateClient)); @@ -807,6 +905,7 @@ public void testStateCallsFailIfNoStateApiServiceDescriptorSpecified() throws Ex fnApiRegistry::get, beamFnDataClient, null /* beamFnStateGrpcClientCache */, + null /* finalizeBundleHandler */, ImmutableMap.of( DATA_INPUT_URN, new PTransformRunnerFactory() { @@ -825,7 +924,8 @@ public Object createRunnerForPTransform( PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, Consumer 
addTearDownFunction, - BundleSplitListener splitListener) + BundleSplitListener splitListener, + BundleFinalizer bundleFinalizer) throws IOException { startFunctionRegistry.register( pTransformId, () -> doStateCalls(beamFnStateClient));