From 7f8c6e8541d37a4f4ee79bbc14e3f43a38d261c6 Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Wed, 30 Aug 2017 13:56:40 -0700 Subject: [PATCH] [BEAM-1347] Wire up the BeamFnStateGrpcClientCache implementation into the ProcessBundleHandler Add a BeamFnStateClient that is dependent on whether the State API service descriptor is populated. --- .../fn-api/src/main/proto/beam_fn_api.proto | 5 + .../org/apache/beam/fn/harness/FnHarness.java | 9 +- .../harness/control/ProcessBundleHandler.java | 147 +++++++++++++---- .../control/ProcessBundleHandlerTest.java | 153 +++++++++++++++++- 4 files changed, 279 insertions(+), 35 deletions(-) diff --git a/sdks/common/fn-api/src/main/proto/beam_fn_api.proto b/sdks/common/fn-api/src/main/proto/beam_fn_api.proto index 9da5afec1b4e..53d67bce3f01 100644 --- a/sdks/common/fn-api/src/main/proto/beam_fn_api.proto +++ b/sdks/common/fn-api/src/main/proto/beam_fn_api.proto @@ -168,6 +168,11 @@ message ProcessBundleDescriptor { // (Required) A map from pipeline-scoped id to Environment. map environments = 6; + + // A descriptor describing the end point to use for State API + // calls. Required if the Runner intends to send remote references over the + // data plane or if any of the transforms rely on user state or side inputs. + ApiServiceDescriptor state_api_service_descriptor = 7; } // A request to process a given bundle. diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java index a79ecca858ff..49a7a882773a 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java @@ -29,6 +29,7 @@ import org.apache.beam.fn.harness.data.BeamFnDataGrpcClient; import org.apache.beam.fn.harness.fn.ThrowingFunction; import org.apache.beam.fn.harness.logging.BeamFnLoggingClient; +import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache; import org.apache.beam.fn.harness.stream.StreamObserverFactory; import org.apache.beam.fn.v1.BeamFnApi; import org.apache.beam.sdk.extensions.gcp.options.GcsOptions; @@ -109,11 +110,17 @@ public static void main(PipelineOptions options, BeamFnDataGrpcClient beamFnDataMultiplexer = new BeamFnDataGrpcClient( options, channelFactory::forDescriptor, streamObserverFactory::from); + BeamFnStateGrpcClientCache beamFnStateGrpcClientCache = new BeamFnStateGrpcClientCache( + options, + IdGenerator::generate, + channelFactory::forDescriptor, + streamObserverFactory::from); + ProcessBundleHandler processBundleHandler = new ProcessBundleHandler( options, fnApiRegistry::getById, beamFnDataMultiplexer, - null /* beamFnStateClient */); + beamFnStateGrpcClientCache); handlers.put(BeamFnApi.InstructionRequest.RequestCase.REGISTER, fnApiRegistry::register); handlers.put(BeamFnApi.InstructionRequest.RequestCase.PROCESS_BUNDLE, diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index 67c4d6778d8d..e094487f1275 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -31,6 +31,8 @@ import java.util.Map; import java.util.ServiceLoader; import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Phaser; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -40,7 +42,13 @@ import org.apache.beam.fn.harness.fn.ThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.harness.state.BeamFnStateClient; +import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache; import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.fn.v1.BeamFnApi.ApiServiceDescriptor; +import org.apache.beam.fn.v1.BeamFnApi.ProcessBundleRequest; +import org.apache.beam.fn.v1.BeamFnApi.StateRequest; +import org.apache.beam.fn.v1.BeamFnApi.StateRequest.Builder; +import org.apache.beam.fn.v1.BeamFnApi.StateResponse; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.WindowedValue; @@ -84,7 +92,7 @@ public class ProcessBundleHandler { private final PipelineOptions options; private final Function fnApiRegistry; private final BeamFnDataClient beamFnDataClient; - private final BeamFnStateClient beamFnStateClient; + private final BeamFnStateGrpcClientCache beamFnStateGrpcClientCache; private final Map urnToPTransformRunnerFactoryMap; private final PTransformRunnerFactory defaultPTransformRunnerFactory; @@ -93,8 +101,12 @@ public ProcessBundleHandler( PipelineOptions options, Function fnApiRegistry, BeamFnDataClient beamFnDataClient, - BeamFnStateClient beamFnStateClient) { - this(options, fnApiRegistry, beamFnDataClient, beamFnStateClient, REGISTERED_RUNNER_FACTORIES); + BeamFnStateGrpcClientCache beamFnStateGrpcClientCache) { + this(options, + fnApiRegistry, + beamFnDataClient, + beamFnStateGrpcClientCache, + REGISTERED_RUNNER_FACTORIES); } @VisibleForTesting @@ -102,12 +114,12 @@ public ProcessBundleHandler( PipelineOptions options, Function fnApiRegistry, BeamFnDataClient beamFnDataClient, - BeamFnStateClient beamFnStateClient, + BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, Map urnToPTransformRunnerFactoryMap) { this.options = options; this.fnApiRegistry = fnApiRegistry; this.beamFnDataClient = beamFnDataClient; - this.beamFnStateClient = beamFnStateClient; + this.beamFnStateGrpcClientCache = beamFnStateGrpcClientCache; this.urnToPTransformRunnerFactoryMap = urnToPTransformRunnerFactoryMap; this.defaultPTransformRunnerFactory = new PTransformRunnerFactory() { @Override @@ -132,6 +144,7 @@ public Object createRunnerForPTransform( } private void createRunnerAndConsumersForPTransformRecursively( + BeamFnStateClient beamFnStateClient, String pTransformId, RunnerApi.PTransform pTransform, Supplier processBundleInstructionId, @@ -152,6 +165,7 @@ private void createRunnerAndConsumersForPTransformRecursively( for (String consumingPTransformId : pCollectionIdsToConsumingPTransforms.get(pCollectionId)) { createRunnerAndConsumersForPTransformRecursively( + beamFnStateClient, consumingPTransformId, processBundleDescriptor.getTransformsMap().get(consumingPTransformId), processBundleInstructionId, @@ -204,39 +218,110 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction } } - // - for (Map.Entry entry - : bundleDescriptor.getTransformsMap().entrySet()) { - // Skip anything which isn't a root - // TODO: Remove source as a root and have it be triggered by the Runner. - if (!DATA_INPUT_URN.equals(entry.getValue().getSpec().getUrn()) - && !JAVA_SOURCE_URN.equals(entry.getValue().getSpec().getUrn())) { - continue; + // Instantiate a State API call handler depending on whether a State Api service descriptor + // was specified. + try (HandleStateCallsForBundle beamFnStateClient = + bundleDescriptor.hasStateApiServiceDescriptor() + ? new BlockTillStateCallsFinish(beamFnStateGrpcClientCache.forApiServiceDescriptor( + bundleDescriptor.getStateApiServiceDescriptor())) + : new FailAllStateCallsForBundle(request.getProcessBundle())) { + // Create a BeamFnStateClient + for (Map.Entry entry + : bundleDescriptor.getTransformsMap().entrySet()) { + // Skip anything which isn't a root + // TODO: Remove source as a root and have it be triggered by the Runner. + if (!DATA_INPUT_URN.equals(entry.getValue().getSpec().getUrn()) + && !JAVA_SOURCE_URN.equals(entry.getValue().getSpec().getUrn())) { + continue; + } + + createRunnerAndConsumersForPTransformRecursively( + beamFnStateClient, + entry.getKey(), + entry.getValue(), + request::getInstructionId, + bundleDescriptor, + pCollectionIdsToConsumingPTransforms, + pCollectionIdsToConsumers, + startFunctions::add, + finishFunctions::add); + } + + // Already in reverse topological order so we don't need to do anything. + for (ThrowingRunnable startFunction : startFunctions) { + LOG.debug("Starting function {}", startFunction); + startFunction.run(); } - createRunnerAndConsumersForPTransformRecursively( - entry.getKey(), - entry.getValue(), - request::getInstructionId, - bundleDescriptor, - pCollectionIdsToConsumingPTransforms, - pCollectionIdsToConsumers, - startFunctions::add, - finishFunctions::add); + // Need to reverse this since we want to call finish in topological order. + for (ThrowingRunnable finishFunction : Lists.reverse(finishFunctions)) { + LOG.debug("Finishing function {}", finishFunction); + finishFunction.run(); + } } - // Already in reverse topological order so we don't need to do anything. - for (ThrowingRunnable startFunction : startFunctions) { - LOG.debug("Starting function {}", startFunction); - startFunction.run(); + return response; + } + + /** + * A {@link BeamFnStateClient} which counts the number of outstanding {@link StateRequest}s and + * blocks till they are all finished. + */ + private class BlockTillStateCallsFinish extends HandleStateCallsForBundle { + private final BeamFnStateClient beamFnStateClient; + private final Phaser phaser; + private int currentPhase; + + private BlockTillStateCallsFinish(BeamFnStateClient beamFnStateClient) { + this.beamFnStateClient = beamFnStateClient; + this.phaser = new Phaser(1 /* initial party is the process bundle handler */); + this.currentPhase = phaser.getPhase(); } - // Need to reverse this since we want to call finish in topological order. - for (ThrowingRunnable finishFunction : Lists.reverse(finishFunctions)) { - LOG.debug("Finishing function {}", finishFunction); - finishFunction.run(); + @Override + public void close() throws Exception { + int unarrivedParties = phaser.getUnarrivedParties(); + if (unarrivedParties > 0) { + LOG.debug("Waiting for {} parties to arrive before closing, current phase {}.", + unarrivedParties, currentPhase); + } + currentPhase = phaser.arriveAndAwaitAdvance(); } - return response; + @Override + public void handle(StateRequest.Builder requestBuilder, + CompletableFuture response) { + // Register each request with the phaser and arrive and deregister each time a request + // completes. + phaser.register(); + response.whenComplete((stateResponse, throwable) -> phaser.arriveAndDeregister()); + beamFnStateClient.handle(requestBuilder, response); + } + } + + /** + * A {@link BeamFnStateClient} which fails all requests because the {@link ProcessBundleRequest} + * does not contain a State API {@link ApiServiceDescriptor}. + */ + private class FailAllStateCallsForBundle extends HandleStateCallsForBundle { + private final ProcessBundleRequest request; + + private FailAllStateCallsForBundle(ProcessBundleRequest request) { + this.request = request; + } + + @Override + public void close() throws Exception { + // no-op + } + + @Override + public void handle(Builder requestBuilder, CompletableFuture response) { + throw new IllegalStateException(String.format("State API calls are unsupported because the " + + "ProcessBundleRequest %s does not support state.", request)); + } + } + + private abstract class HandleStateCallsForBundle implements AutoCloseable, BeamFnStateClient { } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index d0e1faf6f246..94fa6ade06c4 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -21,14 +21,21 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Multimap; +import com.google.common.util.concurrent.Uninterruptibles; import com.google.protobuf.Message; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.PTransformRunnerFactory; @@ -36,7 +43,11 @@ import org.apache.beam.fn.harness.fn.ThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.harness.state.BeamFnStateClient; +import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache; import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.fn.v1.BeamFnApi.ApiServiceDescriptor; +import org.apache.beam.fn.v1.BeamFnApi.StateRequest; +import org.apache.beam.fn.v1.BeamFnApi.StateResponse; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; @@ -50,7 +61,10 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; /** Tests for {@link ProcessBundleHandler}. */ @RunWith(JUnit4.class) @@ -150,7 +164,7 @@ public void testCreatingPTransformExceptionsArePropagated() throws Exception { PipelineOptionsFactory.create(), fnApiRegistry::get, beamFnDataClient, - null /* beamFnStateClient */, + null /* beamFnStateGrpcClientCache */, ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory() { @Override public Object createRunnerForPTransform( @@ -190,7 +204,7 @@ public void testPTransformStartExceptionsArePropagated() throws Exception { PipelineOptionsFactory.create(), fnApiRegistry::get, beamFnDataClient, - null /* beamFnStateClient */, + null /* beamFnStateGrpcClientCache */, ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory() { @Override public Object createRunnerForPTransform( @@ -231,7 +245,7 @@ public void testPTransformFinishExceptionsArePropagated() throws Exception { PipelineOptionsFactory.create(), fnApiRegistry::get, beamFnDataClient, - null /* beamFnStateClient */, + null /* beamFnStateGrpcClientCache */, ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory() { @Override public Object createRunnerForPTransform( @@ -258,6 +272,139 @@ public Object createRunnerForPTransform( .build()); } + @Test + public void testPendingStateCallsBlockTillCompletion() throws Exception { + BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = + BeamFnApi.ProcessBundleDescriptor.newBuilder() + .putTransforms("2L", RunnerApi.PTransform.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()) + .build()) + .setStateApiServiceDescriptor(ApiServiceDescriptor.getDefaultInstance()) + .build(); + Map fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor); + + CompletableFuture successfulResponse = new CompletableFuture<>(); + CompletableFuture unsuccessfulResponse = new CompletableFuture<>(); + + BeamFnStateGrpcClientCache mockBeamFnStateGrpcClient = + Mockito.mock(BeamFnStateGrpcClientCache.class); + BeamFnStateClient mockBeamFnStateClient = Mockito.mock(BeamFnStateClient.class); + when(mockBeamFnStateGrpcClient.forApiServiceDescriptor(any())) + .thenReturn(mockBeamFnStateClient); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + StateRequest.Builder stateRequestBuilder = + (StateRequest.Builder) invocation.getArguments()[0]; + CompletableFuture completableFuture = + (CompletableFuture) invocation.getArguments()[1]; + new Thread() { + @Override + public void run() { + // Simulate sleeping which introduces a race which most of the time requires + // the ProcessBundleHandler to block. + Uninterruptibles.sleepUninterruptibly(500, TimeUnit.MILLISECONDS); + switch (stateRequestBuilder.getInstructionReference()) { + case "SUCCESS": + completableFuture.complete(StateResponse.getDefaultInstance()); + break; + case "FAIL": + completableFuture.completeExceptionally(new RuntimeException("TEST ERROR")); + } + } + }.start(); + return null; + } + }).when(mockBeamFnStateClient).handle(any(), any()); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient, + mockBeamFnStateGrpcClient, + ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory() { + @Override + public Object createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, + BeamFnStateClient beamFnStateClient, + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) throws IOException { + addStartFunction.accept(() -> doStateCalls(beamFnStateClient)); + return null; + } + + private void doStateCalls(BeamFnStateClient beamFnStateClient) { + beamFnStateClient.handle(StateRequest.newBuilder().setInstructionReference("SUCCESS"), + successfulResponse); + beamFnStateClient.handle(StateRequest.newBuilder().setInstructionReference("FAIL"), + unsuccessfulResponse); + } + })); + handler.processBundle( + BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( + BeamFnApi.ProcessBundleRequest.newBuilder() + .setProcessBundleDescriptorReference("1L")) + .build()); + + assertTrue(successfulResponse.isDone()); + assertTrue(unsuccessfulResponse.isDone()); + } + + @Test + public void testStateCallsFailIfNoStateApiServiceDescriptorSpecified() throws Exception { + BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = + BeamFnApi.ProcessBundleDescriptor.newBuilder() + .putTransforms("2L", RunnerApi.PTransform.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()) + .build()) + .build(); + Map fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient, + null /* beamFnStateGrpcClientCache */, + ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory() { + @Override + public Object createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, + BeamFnStateClient beamFnStateClient, + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) throws IOException { + addStartFunction.accept(() -> doStateCalls(beamFnStateClient)); + return null; + } + + private void doStateCalls(BeamFnStateClient beamFnStateClient) { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("State API calls are unsupported"); + beamFnStateClient.handle(StateRequest.newBuilder().setInstructionReference("SUCCESS"), + new CompletableFuture<>()); + } + })); + handler.processBundle( + BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( + BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L")) + .build()); + } + + private static void throwException() { throw new IllegalStateException("TestException"); }