From c5e90f1ba428b7ae35b5179616ab625aae6db91c Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Mon, 5 Jun 2017 11:01:54 -0700 Subject: [PATCH] [BEAM-1347] Migrate to Runner API constructs within the Java SDK harness --- sdks/java/harness/pom.xml | 5 + .../harness/control/ProcessBundleHandler.java | 178 +++++--- .../fn/harness/control/RegisterHandler.java | 12 +- .../data/BeamFnDataGrpcMultiplexer.java | 8 +- .../runners/core/BeamFnDataReadRunner.java | 12 +- .../runners/core/BeamFnDataWriteRunner.java | 12 +- .../runners/core/BoundedSourceRunner.java | 10 +- .../control/ProcessBundleHandlerTest.java | 400 ++++++++---------- .../harness/control/RegisterHandlerTest.java | 26 +- .../core/BeamFnDataReadRunnerTest.java | 18 +- .../core/BeamFnDataWriteRunnerTest.java | 20 +- .../runners/core/BoundedSourceRunnerTest.java | 8 +- 12 files changed, 371 insertions(+), 338 deletions(-) diff --git a/sdks/java/harness/pom.xml b/sdks/java/harness/pom.xml index 3918fd9fd383..61a170ae4afb 100644 --- a/sdks/java/harness/pom.xml +++ b/sdks/java/harness/pom.xml @@ -86,6 +86,11 @@ beam-runners-google-cloud-dataflow-java + + org.apache.beam + beam-sdks-common-runner-api + + org.apache.beam beam-sdks-common-fn-api diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index fd9f0dfb534d..e33277af15bc 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -38,7 +38,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -55,6 +54,7 @@ import org.apache.beam.runners.core.DoFnRunners.OutputManager; import 
org.apache.beam.runners.core.NullSideInputReader; import org.apache.beam.runners.dataflow.util.DoFnInfo; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; @@ -96,108 +96,145 @@ public ProcessBundleHandler( this.beamFnDataClient = beamFnDataClient; } - protected void createConsumersForPrimitiveTransform( - BeamFnApi.PrimitiveTransform primitiveTransform, + private void createRunnerAndConsumersForPTransformRecursively( + String pTransformId, + RunnerApi.PTransform pTransform, Supplier processBundleInstructionId, - Function>>> consumers, - BiConsumer>> addConsumer, + BeamFnApi.ProcessBundleDescriptor processBundleDescriptor, + Multimap pCollectionIdsToConsumingPTransforms, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) throws IOException { + + // Recursively ensure that all consumers of the output PCollection have been created. + // Since we are creating the consumers first, we know that we are building the DAG + // in reverse topological order. + for (String pCollectionId : pTransform.getOutputsMap().values()) { + // If we have created the consumers for this PCollection we can skip it. 
+ if (pCollectionIdsToConsumers.containsKey(pCollectionId)) { + continue; + } + + for (String consumingPTransformId : pCollectionIdsToConsumingPTransforms.get(pCollectionId)) { + createRunnerAndConsumersForPTransformRecursively( + consumingPTransformId, + processBundleDescriptor.getTransformsMap().get(consumingPTransformId), + processBundleInstructionId, + processBundleDescriptor, + pCollectionIdsToConsumingPTransforms, + pCollectionIdsToConsumers, + addStartFunction, + addFinishFunction); + } + } + + createRunnerForPTransform( + pTransformId, + pTransform, + processBundleInstructionId, + processBundleDescriptor.getPcollectionsMap(), + pCollectionIdsToConsumers, + addStartFunction, + addFinishFunction); + } + + protected void createRunnerForPTransform( + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Multimap>> pCollectionIdsToConsumers, Consumer addStartFunction, Consumer addFinishFunction) throws IOException { - BeamFnApi.FunctionSpec functionSpec = primitiveTransform.getFunctionSpec(); // For every output PCollection, create a map from output name to Consumer - ImmutableMap.Builder>>> + ImmutableMap.Builder>>> outputMapBuilder = ImmutableMap.builder(); - for (Map.Entry entry : - primitiveTransform.getOutputsMap().entrySet()) { + for (Map.Entry entry : pTransform.getOutputsMap().entrySet()) { outputMapBuilder.put( entry.getKey(), - consumers.apply( - BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(primitiveTransform.getId()) - .setName(entry.getKey()) - .build())); + pCollectionIdsToConsumers.get(entry.getValue())); } - ImmutableMap>>> outputMap = + ImmutableMap>>> outputMap = outputMapBuilder.build(); + // Based upon the function spec, populate the start/finish/consumer information. 
- ThrowingConsumer> consumer; + RunnerApi.FunctionSpec functionSpec = pTransform.getSpec(); + ThrowingConsumer> consumer; switch (functionSpec.getUrn()) { default: BeamFnApi.Target target; - BeamFnApi.Coder coderSpec; + RunnerApi.Coder coderSpec; throw new IllegalArgumentException( String.format("Unknown FunctionSpec %s", functionSpec)); case DATA_OUTPUT_URN: target = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(primitiveTransform.getId()) - .setName(getOnlyElement(primitiveTransform.getOutputsMap().keySet())) + .setPrimitiveTransformReference(pTransformId) + .setName(getOnlyElement(pTransform.getInputsMap().keySet())) .build(); - coderSpec = (BeamFnApi.Coder) fnApiRegistry.apply( - getOnlyElement(primitiveTransform.getOutputsMap().values()).getCoderReference()); - BeamFnDataWriteRunner remoteGrpcWriteRunner = - new BeamFnDataWriteRunner<>( + coderSpec = (RunnerApi.Coder) fnApiRegistry.apply( + pCollections.get(getOnlyElement(pTransform.getInputsMap().values())).getCoderId()); + BeamFnDataWriteRunner remoteGrpcWriteRunner = + new BeamFnDataWriteRunner( functionSpec, processBundleInstructionId, target, coderSpec, beamFnDataClient); addStartFunction.accept(remoteGrpcWriteRunner::registerForOutput); - consumer = remoteGrpcWriteRunner::consume; + consumer = (ThrowingConsumer) + (ThrowingConsumer>) remoteGrpcWriteRunner::consume; addFinishFunction.accept(remoteGrpcWriteRunner::close); break; case DATA_INPUT_URN: target = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(primitiveTransform.getId()) - .setName(getOnlyElement(primitiveTransform.getInputsMap().keySet())) + .setPrimitiveTransformReference(pTransformId) + .setName(getOnlyElement(pTransform.getOutputsMap().keySet())) .build(); - coderSpec = (BeamFnApi.Coder) fnApiRegistry.apply( - getOnlyElement(primitiveTransform.getOutputsMap().values()).getCoderReference()); - BeamFnDataReadRunner remoteGrpcReadRunner = - new BeamFnDataReadRunner<>( + coderSpec = (RunnerApi.Coder) 
fnApiRegistry.apply( + pCollections.get(getOnlyElement(pTransform.getOutputsMap().values())).getCoderId()); + BeamFnDataReadRunner remoteGrpcReadRunner = + new BeamFnDataReadRunner( functionSpec, processBundleInstructionId, target, coderSpec, beamFnDataClient, - outputMap); + (Map) outputMap); addStartFunction.accept(remoteGrpcReadRunner::registerInputLocation); consumer = null; addFinishFunction.accept(remoteGrpcReadRunner::blockTillReadFinishes); break; case JAVA_DO_FN_URN: - DoFnRunner doFnRunner = createDoFnRunner(functionSpec, outputMap); + DoFnRunner doFnRunner = createDoFnRunner(functionSpec, (Map) outputMap); addStartFunction.accept(doFnRunner::startBundle); + consumer = (ThrowingConsumer) + (ThrowingConsumer>) doFnRunner::processElement; addFinishFunction.accept(doFnRunner::finishBundle); - consumer = doFnRunner::processElement; break; case JAVA_SOURCE_URN: @SuppressWarnings({"unchecked", "rawtypes"}) - BoundedSourceRunner, OutputT> sourceRunner = - createBoundedSourceRunner(functionSpec, outputMap); - @SuppressWarnings({"unchecked", "rawtypes"}) - ThrowingConsumer> sourceConsumer = - (ThrowingConsumer) - (ThrowingConsumer>>) - sourceRunner::runReadLoop; + BoundedSourceRunner, Object> sourceRunner = + createBoundedSourceRunner(functionSpec, (Map) outputMap); // TODO: Remove and replace with source being sent across gRPC port addStartFunction.accept(sourceRunner::start); - consumer = (ThrowingConsumer) sourceConsumer; + consumer = (ThrowingConsumer) + (ThrowingConsumer>>) + sourceRunner::runReadLoop; break; } + // If we created a consumer, add it to the map containing PCollection ids to consumers if (consumer != null) { - for (Map.Entry entry : - primitiveTransform.getInputsMap().entrySet()) { - for (BeamFnApi.Target target : entry.getValue().getTargetList()) { - addConsumer.accept(target, consumer); - } + for (String inputPCollectionId : + pTransform.getInputsMap().values()) { + pCollectionIdsToConsumers.put(inputPCollectionId, consumer); } } } @@ -212,26 
+249,43 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction BeamFnApi.ProcessBundleDescriptor bundleDescriptor = (BeamFnApi.ProcessBundleDescriptor) fnApiRegistry.apply(bundleId); - Multimap>> outputTargetToConsumer = - HashMultimap.create(); + Multimap pCollectionIdsToConsumingPTransforms = HashMultimap.create(); + Multimap>> pCollectionIdsToConsumers = + HashMultimap.create(); List startFunctions = new ArrayList<>(); List finishFunctions = new ArrayList<>(); - // We process the primitive transform list in reverse order - // because we assume that the runner provides it in topologically order. - // This means that all the start/finish functions will be in reverse topological order. - for (BeamFnApi.PrimitiveTransform primitiveTransform : - Lists.reverse(bundleDescriptor.getPrimitiveTransformList())) { - createConsumersForPrimitiveTransform( - primitiveTransform, + + // Build a multimap of PCollection ids to PTransform ids which consume said PCollections + for (Map.Entry entry + : bundleDescriptor.getTransformsMap().entrySet()) { + for (String pCollectionId : entry.getValue().getInputsMap().values()) { + pCollectionIdsToConsumingPTransforms.put(pCollectionId, entry.getKey()); + } + } + + // Create the runner for each root PTransform; its downstream consumers are created recursively. + for (Map.Entry entry + : bundleDescriptor.getTransformsMap().entrySet()) { + // Skip anything which isn't a root + // TODO: Remove source as a root and have it be triggered by the Runner. + if (!DATA_INPUT_URN.equals(entry.getValue().getSpec().getUrn()) + && !JAVA_SOURCE_URN.equals(entry.getValue().getSpec().getUrn())) { + continue; + } + + createRunnerAndConsumersForPTransformRecursively( + entry.getKey(), + entry.getValue(), request::getInstructionId, - outputTargetToConsumer::get, - outputTargetToConsumer::put, + bundleDescriptor, + pCollectionIdsToConsumingPTransforms, + pCollectionIdsToConsumers, startFunctions::add, finishFunctions::add); } - // Already in reverse order so we don't need to do anything. 
+ // Already in reverse topological order so we don't need to do anything. for (ThrowingRunnable startFunction : startFunctions) { LOG.debug("Starting function {}", startFunction); startFunction.run(); @@ -250,11 +304,11 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction * Converts a {@link org.apache.beam.fn.v1.BeamFnApi.FunctionSpec} into a {@link DoFnRunner}. */ private DoFnRunner createDoFnRunner( - BeamFnApi.FunctionSpec functionSpec, + RunnerApi.FunctionSpec functionSpec, Map>>> outputMap) { ByteString serializedFn; try { - serializedFn = functionSpec.getData().unpack(BytesValue.class).getValue(); + serializedFn = functionSpec.getParameter().unpack(BytesValue.class).getValue(); } catch (InvalidProtocolBufferException e) { throw new IllegalArgumentException( String.format("Unable to unwrap DoFn %s", functionSpec), e); @@ -321,7 +375,7 @@ public void output(TupleTag tag, WindowedValue output) { private , OutputT> BoundedSourceRunner createBoundedSourceRunner( - BeamFnApi.FunctionSpec functionSpec, + RunnerApi.FunctionSpec functionSpec, Map>>> outputMap) { @SuppressWarnings({"rawtypes", "unchecked"}) diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/RegisterHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/RegisterHandler.java index fb0623123524..276a1200df01 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/RegisterHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/RegisterHandler.java @@ -19,12 +19,14 @@ package org.apache.beam.fn.harness.control; import com.google.protobuf.Message; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; import org.apache.beam.fn.v1.BeamFnApi; import org.apache.beam.fn.v1.BeamFnApi.RegisterResponse; +import 
org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,6 +47,7 @@ public RegisterHandler() { public T getById(String id) { try { + LOG.debug("Attempting to find {}", id); @SuppressWarnings("unchecked") CompletableFuture returnValue = (CompletableFuture) computeIfAbsent(id); /* @@ -75,11 +78,12 @@ public BeamFnApi.InstructionResponse.Builder register(BeamFnApi.InstructionReque processBundleDescriptor.getId(), processBundleDescriptor.getClass()); computeIfAbsent(processBundleDescriptor.getId()).complete(processBundleDescriptor); - for (BeamFnApi.Coder coder : processBundleDescriptor.getCodersList()) { + for (Map.Entry entry + : processBundleDescriptor.getCodersyyyMap().entrySet()) { LOG.debug("Registering {} with type {}", - coder.getFunctionSpec().getId(), - coder.getClass()); - computeIfAbsent(coder.getFunctionSpec().getId()).complete(coder); + entry.getKey(), + entry.getValue().getClass()); + computeIfAbsent(entry.getKey()).complete(entry.getValue()); } } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcMultiplexer.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcMultiplexer.java index 53dfe11cc301..15e8c0d450f4 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcMultiplexer.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcMultiplexer.java @@ -84,7 +84,7 @@ public CompletableFuture> futureForKey( KV key) { return consumers.computeIfAbsent( key, - (KV providedKey) -> new CompletableFuture<>()); + (KV unused) -> new CompletableFuture<>()); } /** @@ -102,7 +102,11 @@ public void onNext(BeamFnApi.Elements value) { try { KV key = KV.of(data.getInstructionReference(), data.getTarget()); - futureForKey(key).get().accept(data); + CompletableFuture> consumer = futureForKey(key); + if (!consumer.isDone()) { + LOG.debug("Received data for key {} without consumer 
ready.", key); + } + consumer.get().accept(data); if (data.getData().isEmpty()) { consumers.remove(key); } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java index e6928d1aa43b..f0fe2748d51e 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java @@ -33,6 +33,7 @@ import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.slf4j.Logger; @@ -61,14 +62,14 @@ public class BeamFnDataReadRunner { private CompletableFuture readFuture; public BeamFnDataReadRunner( - BeamFnApi.FunctionSpec functionSpec, + RunnerApi.FunctionSpec functionSpec, Supplier processBundleInstructionIdSupplier, BeamFnApi.Target inputTarget, - BeamFnApi.Coder coderSpec, + RunnerApi.Coder coderSpec, BeamFnDataClient beamFnDataClientFactory, Map>>> outputMap) throws IOException { - this.apiServiceDescriptor = functionSpec.getData().unpack(BeamFnApi.RemoteGrpcPort.class) + this.apiServiceDescriptor = functionSpec.getParameter().unpack(BeamFnApi.RemoteGrpcPort.class) .getApiServiceDescriptor(); this.inputTarget = inputTarget; this.processBundleInstructionIdSupplier = processBundleInstructionIdSupplier; @@ -82,8 +83,9 @@ public BeamFnDataReadRunner( CloudObject.fromSpec( OBJECT_MAPPER.readValue( coderSpec - .getFunctionSpec() - .getData() + .getSpec() + .getSpec() + .getParameter() .unpack(BytesValue.class) .getValue() .newInput(), diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java 
b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java index a78da5d201d7..a48df1210a47 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java @@ -29,6 +29,7 @@ import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; @@ -51,13 +52,13 @@ public class BeamFnDataWriteRunner { private CloseableThrowingConsumer> consumer; public BeamFnDataWriteRunner( - BeamFnApi.FunctionSpec functionSpec, + RunnerApi.FunctionSpec functionSpec, Supplier processBundleInstructionIdSupplier, BeamFnApi.Target outputTarget, - BeamFnApi.Coder coderSpec, + RunnerApi.Coder coderSpec, BeamFnDataClient beamFnDataClientFactory) throws IOException { - this.apiServiceDescriptor = functionSpec.getData().unpack(BeamFnApi.RemoteGrpcPort.class) + this.apiServiceDescriptor = functionSpec.getParameter().unpack(BeamFnApi.RemoteGrpcPort.class) .getApiServiceDescriptor(); this.beamFnDataClientFactory = beamFnDataClientFactory; this.processBundleInstructionIdSupplier = processBundleInstructionIdSupplier; @@ -70,8 +71,9 @@ public BeamFnDataWriteRunner( CloudObject.fromSpec( OBJECT_MAPPER.readValue( coderSpec - .getFunctionSpec() - .getData() + .getSpec() + .getSpec() + .getParameter() .unpack(BytesValue.class) .getValue() .newInput(), diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java index 9d9c4334fcc8..4d530b8f79ff 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java +++ 
b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java @@ -26,7 +26,7 @@ import java.util.Collection; import java.util.Map; import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.Source.Reader; import org.apache.beam.sdk.options.PipelineOptions; @@ -39,12 +39,12 @@ */ public class BoundedSourceRunner, OutputT> { private final PipelineOptions pipelineOptions; - private final BeamFnApi.FunctionSpec definition; + private final RunnerApi.FunctionSpec definition; private final Collection>> consumers; public BoundedSourceRunner( PipelineOptions pipelineOptions, - BeamFnApi.FunctionSpec definition, + RunnerApi.FunctionSpec definition, Map>>> outputMap) { this.pipelineOptions = pipelineOptions; this.definition = definition; @@ -61,7 +61,7 @@ public void start() throws Exception { try { // The representation here is defined as the java serialized representation of the // bounded source object packed into a protobuf Any using a protobuf BytesValue wrapper. 
- byte[] bytes = definition.getData().unpack(BytesValue.class).getValue().toByteArray(); + byte[] bytes = definition.getParameter().unpack(BytesValue.class).getValue().toByteArray(); @SuppressWarnings("unchecked") InputT boundedSource = (InputT) SerializableUtils.deserializeFromByteArray(bytes, definition.toString()); @@ -69,7 +69,7 @@ public void start() throws Exception { } catch (InvalidProtocolBufferException e) { throw new IOException( String.format("Failed to decode %s, expected %s", - definition.getData().getTypeUrl(), BytesValue.getDescriptor().getFullName()), + definition.getParameter().getTypeUrl(), BytesValue.getDescriptor().getFullName()), e); } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index f40572843bee..562f91fdd210 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -21,9 +21,9 @@ import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow; import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -39,8 +39,6 @@ import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import 
com.google.common.collect.Multimap; import com.google.protobuf.Any; @@ -49,14 +47,11 @@ import com.google.protobuf.Message; import java.io.IOException; import java.util.ArrayList; -import java.util.Collection; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.function.Function; import java.util.function.Supplier; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer; @@ -68,7 +63,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarLongCoder; -import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.io.CountingSource; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.DoFn; @@ -105,22 +100,27 @@ public class ProcessBundleHandlerTest { .setId("58L") .setUrl("TestUrl")) .build(); - private static final BeamFnApi.Coder LONG_CODER_SPEC; - private static final BeamFnApi.Coder STRING_CODER_SPEC; + private static final RunnerApi.Coder LONG_CODER_SPEC; + private static final RunnerApi.Coder STRING_CODER_SPEC; static { try { - STRING_CODER_SPEC = - BeamFnApi.Coder.newBuilder().setFunctionSpec(BeamFnApi.FunctionSpec.newBuilder() - .setId(STRING_CODER_SPEC_ID) - .setData(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( - OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(STRING_CODER)))).build()))) + STRING_CODER_SPEC = RunnerApi.Coder.newBuilder() + .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder() + .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(STRING_CODER)))) + .build()))) + 
.build()) .build(); - LONG_CODER_SPEC = - BeamFnApi.Coder.newBuilder().setFunctionSpec(BeamFnApi.FunctionSpec.newBuilder() - .setId(STRING_CODER_SPEC_ID) - .setData(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( - OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(WindowedValue.getFullCoder( - VarLongCoder.of(), GlobalWindow.Coder.INSTANCE))))).build()))) + LONG_CODER_SPEC = RunnerApi.Coder.newBuilder() + .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder() + .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes( + CloudObjects.asCloudObject(WindowedValue.getFullCoder(VarLongCoder.of(), + GlobalWindow.Coder.INSTANCE))))) + .build()))) + .build()) .build(); } catch (IOException e) { throw new ExceptionInInitializerError(e); @@ -146,12 +146,19 @@ public void setUp() { public void testOrderOfStartAndFinishCalls() throws Exception { BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder() - .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId("2L")) - .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId("3L")) - .build(); + .putTransforms("2L", RunnerApi.PTransform.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()) + .putOutputs("2L-output", "2L-output-pc") + .build()) + .putTransforms("3L", RunnerApi.PTransform.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN).build()) + .putInputs("3L-input", "2L-output-pc") + .build()) + .putPcollections("2L-output-pc", RunnerApi.PCollection.getDefaultInstance()) + .build(); Map fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor); - List transformsProcessed = new ArrayList<>(); + List transformsProcessed = new ArrayList<>(); List orderOfOperations = new ArrayList<>(); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -159,23 +166,22 
@@ public void testOrderOfStartAndFinishCalls() throws Exception { fnApiRegistry::get, beamFnDataClient) { @Override - protected void createConsumersForPrimitiveTransform( - BeamFnApi.PrimitiveTransform primitiveTransform, + protected void createRunnerForPTransform( + String pTransformId, + RunnerApi.PTransform pTransform, Supplier processBundleInstructionId, - Function>>> consumers, - BiConsumer>> addConsumer, + Map pCollections, + Multimap>> pCollectionIdsToConsumers, Consumer addStartFunction, - Consumer addFinishFunction) - throws IOException { + Consumer addFinishFunction) throws IOException { assertThat(processBundleInstructionId.get(), equalTo("999L")); - transformsProcessed.add(primitiveTransform); + transformsProcessed.add(pTransform); addStartFunction.accept( - () -> orderOfOperations.add("Start" + primitiveTransform.getId())); + () -> orderOfOperations.add("Start" + pTransformId)); addFinishFunction.accept( - () -> orderOfOperations.add("Finish" + primitiveTransform.getId())); + () -> orderOfOperations.add("Finish" + pTransformId)); } }; handler.processBundle(BeamFnApi.InstructionRequest.newBuilder() @@ -184,21 +190,22 @@ protected void createConsumersForPrimitiveTransform( BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L")) .build()); - // Processing of primitive transforms is performed in reverse order. + // Processing of transforms is performed in reverse order. 
assertThat(transformsProcessed, contains( - processBundleDescriptor.getPrimitiveTransform(1), - processBundleDescriptor.getPrimitiveTransform(0))); + processBundleDescriptor.getTransformsMap().get("3L"), + processBundleDescriptor.getTransformsMap().get("2L"))); // Start should occur in reverse order while finish calls should occur in forward order assertThat(orderOfOperations, contains("Start3L", "Start2L", "Finish2L", "Finish3L")); } @Test - public void testCreatingPrimitiveTransformExceptionsArePropagated() throws Exception { + public void testCreatingPTransformExceptionsArePropagated() throws Exception { BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder() - .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId("2L")) - .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId("3L")) - .build(); + .putTransforms("2L", RunnerApi.PTransform.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()) + .build()) + .build(); Map fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -206,15 +213,14 @@ public void testCreatingPrimitiveTransformExceptionsArePropagated() throws Excep fnApiRegistry::get, beamFnDataClient) { @Override - protected void createConsumersForPrimitiveTransform( - BeamFnApi.PrimitiveTransform primitiveTransform, + protected void createRunnerForPTransform( + String pTransformId, + RunnerApi.PTransform pTransform, Supplier processBundleInstructionId, - Function>>> consumers, - BiConsumer>> addConsumer, + Map pCollections, + Multimap>> pCollectionIdsToConsumers, Consumer addStartFunction, - Consumer addFinishFunction) - throws IOException { + Consumer addFinishFunction) throws IOException { thrown.expect(IllegalStateException.class); thrown.expectMessage("TestException"); throw new IllegalStateException("TestException"); @@ -223,16 +229,17 @@ protected void 
createConsumersForPrimitiveTransform( handler.processBundle( BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L")) - .build()); + .build()); } @Test - public void testPrimitiveTransformStartExceptionsArePropagated() throws Exception { + public void testPTransformStartExceptionsArePropagated() throws Exception { BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder() - .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId("2L")) - .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId("3L")) - .build(); + .putTransforms("2L", RunnerApi.PTransform.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()) + .build()) + .build(); Map fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -240,15 +247,14 @@ public void testPrimitiveTransformStartExceptionsArePropagated() throws Exceptio fnApiRegistry::get, beamFnDataClient) { @Override - protected void createConsumersForPrimitiveTransform( - BeamFnApi.PrimitiveTransform primitiveTransform, + protected void createRunnerForPTransform( + String pTransformId, + RunnerApi.PTransform pTransform, Supplier processBundleInstructionId, - Function>>> consumers, - BiConsumer>> addConsumer, + Map pCollections, + Multimap>> pCollectionIdsToConsumers, Consumer addStartFunction, - Consumer addFinishFunction) - throws IOException { + Consumer addFinishFunction) throws IOException { thrown.expect(IllegalStateException.class); thrown.expectMessage("TestException"); addStartFunction.accept(this::throwException); @@ -261,16 +267,17 @@ private void throwException() { handler.processBundle( BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L")) - .build()); + .build()); } @Test 
- public void testPrimitiveTransformFinishExceptionsArePropagated() throws Exception { + public void testPTransformFinishExceptionsArePropagated() throws Exception { BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder() - .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId("2L")) - .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId("3L")) - .build(); + .putTransforms("2L", RunnerApi.PTransform.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()) + .build()) + .build(); Map fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -278,15 +285,14 @@ public void testPrimitiveTransformFinishExceptionsArePropagated() throws Excepti fnApiRegistry::get, beamFnDataClient) { @Override - protected void createConsumersForPrimitiveTransform( - BeamFnApi.PrimitiveTransform primitiveTransform, + protected void createRunnerForPTransform( + String pTransformId, + RunnerApi.PTransform pTransform, Supplier processBundleInstructionId, - Function>>> consumers, - BiConsumer>> addConsumer, + Map pCollections, + Multimap>> pCollectionIdsToConsumers, Consumer addStartFunction, - Consumer addFinishFunction) - throws IOException { + Consumer addFinishFunction) throws IOException { thrown.expect(IllegalStateException.class); thrown.expectMessage("TestException"); addFinishFunction.accept(this::throwException); @@ -299,7 +305,7 @@ private void throwException() { handler.processBundle( BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L")) - .build()); + .build()); } private static class TestDoFn extends DoFn { @@ -332,72 +338,40 @@ public void finishBundle(FinishBundleContext context) { @Test public void testCreatingAndProcessingDoFn() throws Exception { Map fnApiRegistry = 
ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); - String primitiveTransformId = "100L"; - long mainOutputId = 101L; - long additionalOutputId = 102L; + String pTransformId = "100L"; + String mainOutputId = "101"; + String additionalOutputId = "102"; DoFnInfo doFnInfo = DoFnInfo.forFn( new TestDoFn(), WindowingStrategy.globalDefault(), ImmutableList.of(), StringUtf8Coder.of(), - mainOutputId, + Long.parseLong(mainOutputId), ImmutableMap.of( - mainOutputId, TestDoFn.mainOutput, - additionalOutputId, TestDoFn.additionalOutput)); - BeamFnApi.FunctionSpec functionSpec = BeamFnApi.FunctionSpec.newBuilder() - .setId("1L") + Long.parseLong(mainOutputId), TestDoFn.mainOutput, + Long.parseLong(additionalOutputId), TestDoFn.additionalOutput)); + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() .setUrn(JAVA_DO_FN_URN) - .setData(Any.pack(BytesValue.newBuilder() + .setParameter(Any.pack(BytesValue.newBuilder() .setValue(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo))) .build())) .build(); - BeamFnApi.Target inputATarget1 = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference("1000L") - .setName("inputATarget1") - .build(); - BeamFnApi.Target inputATarget2 = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference("1001L") - .setName("inputATarget1") - .build(); - BeamFnApi.Target inputBTarget = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference("1002L") - .setName("inputBTarget") - .build(); - BeamFnApi.PrimitiveTransform primitiveTransform = BeamFnApi.PrimitiveTransform.newBuilder() - .setId(primitiveTransformId) - .setFunctionSpec(functionSpec) - .putInputs("inputA", BeamFnApi.Target.List.newBuilder() - .addTarget(inputATarget1) - .addTarget(inputATarget2) - .build()) - .putInputs("inputB", BeamFnApi.Target.List.newBuilder() - .addTarget(inputBTarget) - .build()) - .putOutputs(Long.toString(mainOutputId), BeamFnApi.PCollection.newBuilder() - .setCoderReference(STRING_CODER_SPEC_ID) - 
.build()) - .putOutputs(Long.toString(additionalOutputId), BeamFnApi.PCollection.newBuilder() - .setCoderReference(STRING_CODER_SPEC_ID) - .build()) + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs("inputA", "inputATarget") + .putInputs("inputB", "inputBTarget") + .putOutputs(mainOutputId, "mainOutputTarget") + .putOutputs(additionalOutputId, "additionalOutputTarget") .build(); List> mainOutputValues = new ArrayList<>(); List> additionalOutputValues = new ArrayList<>(); - BeamFnApi.Target mainOutputTarget = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(primitiveTransformId) - .setName(Long.toString(mainOutputId)) - .build(); - BeamFnApi.Target additionalOutputTarget = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(primitiveTransformId) - .setName(Long.toString(additionalOutputId)) - .build(); - Multimap>> existingConsumers = - ImmutableMultimap.of( - mainOutputTarget, mainOutputValues::add, - additionalOutputTarget, additionalOutputValues::add); - Multimap>> newConsumers = - HashMultimap.create(); + Multimap>> consumers = HashMultimap.create(); + consumers.put("mainOutputTarget", + (ThrowingConsumer) (ThrowingConsumer>) mainOutputValues::add); + consumers.put("additionalOutputTarget", + (ThrowingConsumer) (ThrowingConsumer>) additionalOutputValues::add); List startFunctions = new ArrayList<>(); List finishFunctions = new ArrayList<>(); @@ -405,23 +379,24 @@ public void testCreatingAndProcessingDoFn() throws Exception { PipelineOptionsFactory.create(), fnApiRegistry::get, beamFnDataClient); - handler.createConsumersForPrimitiveTransform( - primitiveTransform, + handler.createRunnerForPTransform( + pTransformId, + pTransform, Suppliers.ofInstance("57L")::get, - existingConsumers::get, - newConsumers::put, + ImmutableMap.of(), + consumers, startFunctions::add, finishFunctions::add); Iterables.getOnlyElement(startFunctions).run(); mainOutputValues.clear(); - 
assertEquals(newConsumers.keySet(), - ImmutableSet.of(inputATarget1, inputATarget2, inputBTarget)); + assertThat(consumers.keySet(), containsInAnyOrder( + "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget")); - Iterables.getOnlyElement(newConsumers.get(inputATarget1)).accept(valueInGlobalWindow("A1")); - Iterables.getOnlyElement(newConsumers.get(inputATarget1)).accept(valueInGlobalWindow("A2")); - Iterables.getOnlyElement(newConsumers.get(inputATarget1)).accept(valueInGlobalWindow("B")); + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1")); + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2")); + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("B")); assertThat(mainOutputValues, contains( valueInGlobalWindow("MainOutputA1"), valueInGlobalWindow("MainOutputA2"), @@ -444,44 +419,26 @@ public void testCreatingAndProcessingDoFn() throws Exception { @Test public void testCreatingAndProcessingSource() throws Exception { Map fnApiRegistry = ImmutableMap.of(LONG_CODER_SPEC_ID, LONG_CODER_SPEC); - String primitiveTransformId = "100L"; - long outputId = 101L; - - BeamFnApi.Target inputTarget = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference("1000L") - .setName("inputTarget") - .build(); - List> outputValues = new ArrayList<>(); - BeamFnApi.Target outputTarget = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(primitiveTransformId) - .setName(Long.toString(outputId)) - .build(); - Multimap>> existingConsumers = - ImmutableMultimap.of(outputTarget, outputValues::add); - Multimap>>> newConsumers = - HashMultimap.create(); + Multimap>> consumers = HashMultimap.create(); + consumers.put("outputPC", + (ThrowingConsumer) (ThrowingConsumer>) outputValues::add); List startFunctions = new ArrayList<>(); List finishFunctions = new ArrayList<>(); - BeamFnApi.FunctionSpec functionSpec = BeamFnApi.FunctionSpec.newBuilder() 
- .setId("1L") + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() .setUrn(JAVA_SOURCE_URN) - .setData(Any.pack(BytesValue.newBuilder() + .setParameter(Any.pack(BytesValue.newBuilder() .setValue(ByteString.copyFrom( SerializableUtils.serializeToByteArray(CountingSource.upTo(3)))) .build())) .build(); - BeamFnApi.PrimitiveTransform primitiveTransform = BeamFnApi.PrimitiveTransform.newBuilder() - .setId(primitiveTransformId) - .setFunctionSpec(functionSpec) - .putInputs("input", - BeamFnApi.Target.List.newBuilder().addTarget(inputTarget).build()) - .putOutputs(Long.toString(outputId), - BeamFnApi.PCollection.newBuilder().setCoderReference(LONG_CODER_SPEC_ID).build()) + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs("input", "inputPC") + .putOutputs("output", "outputPC") .build(); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -489,11 +446,12 @@ public void testCreatingAndProcessingSource() throws Exception { fnApiRegistry::get, beamFnDataClient); - handler.createConsumersForPrimitiveTransform( - primitiveTransform, + handler.createRunnerForPTransform( + "pTransformId", + pTransform, Suppliers.ofInstance("57L")::get, - existingConsumers::get, - newConsumers::put, + ImmutableMap.of(), + consumers, startFunctions::add, finishFunctions::add); @@ -507,8 +465,8 @@ public void testCreatingAndProcessingSource() throws Exception { outputValues.clear(); // Check that when passing a source along as an input, the source is processed. 
- assertEquals(newConsumers.keySet(), ImmutableSet.of(inputTarget)); - Iterables.getOnlyElement(newConsumers.get(inputTarget)).accept( + assertThat(consumers.keySet(), containsInAnyOrder("inputPC", "outputPC")); + Iterables.getOnlyElement(consumers.get("inputPC")).accept( valueInGlobalWindow(CountingSource.upTo(2))); assertThat(outputValues, contains( valueInGlobalWindow(0L), @@ -520,35 +478,25 @@ public void testCreatingAndProcessingSource() throws Exception { @Test public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { Map fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); - String bundleId = "57L"; - String primitiveTransformId = "100L"; - long outputId = 101L; + String bundleId = "57"; + String outputId = "101"; List> outputValues = new ArrayList<>(); - BeamFnApi.Target outputTarget = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(primitiveTransformId) - .setName(Long.toString(outputId)) - .build(); - Multimap>> existingConsumers = - ImmutableMultimap.of(outputTarget, outputValues::add); - Multimap>> newConsumers = - HashMultimap.create(); + Multimap>> consumers = HashMultimap.create(); + consumers.put("outputPC", + (ThrowingConsumer) (ThrowingConsumer>) outputValues::add); List startFunctions = new ArrayList<>(); List finishFunctions = new ArrayList<>(); - BeamFnApi.FunctionSpec functionSpec = BeamFnApi.FunctionSpec.newBuilder() - .setId("1L") + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() .setUrn(DATA_INPUT_URN) - .setData(Any.pack(REMOTE_PORT)) + .setParameter(Any.pack(REMOTE_PORT)) .build(); - BeamFnApi.PrimitiveTransform primitiveTransform = BeamFnApi.PrimitiveTransform.newBuilder() - .setId(primitiveTransformId) - .setFunctionSpec(functionSpec) - .putInputs("input", BeamFnApi.Target.List.getDefaultInstance()) - .putOutputs(Long.toString(outputId), - BeamFnApi.PCollection.newBuilder().setCoderReference(STRING_CODER_SPEC_ID).build()) + RunnerApi.PTransform 
pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putOutputs(outputId, "outputPC") .build(); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -556,11 +504,13 @@ public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { fnApiRegistry::get, beamFnDataClient); - handler.createConsumersForPrimitiveTransform( - primitiveTransform, + handler.createRunnerForPTransform( + "pTransformId", + pTransform, Suppliers.ofInstance(bundleId)::get, - existingConsumers::get, - newConsumers::put, + ImmutableMap.of("outputPC", + RunnerApi.PCollection.newBuilder().setCoderId(STRING_CODER_SPEC_ID).build()), + consumers, startFunctions::add, finishFunctions::add); @@ -573,8 +523,8 @@ public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { verify(beamFnDataClient).forInboundConsumer( eq(REMOTE_PORT.getApiServiceDescriptor()), eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(primitiveTransformId) - .setName("input") + .setPrimitiveTransformReference("pTransformId") + .setName(outputId) .build())), eq(STRING_CODER), consumerCaptor.capture()); @@ -583,7 +533,7 @@ public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); outputValues.clear(); - assertThat(newConsumers.keySet(), empty()); + assertThat(consumers.keySet(), containsInAnyOrder("outputPC")); completionFuture.complete(null); Iterables.getOnlyElement(finishFunctions).run(); @@ -595,33 +545,20 @@ public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { Map fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); String bundleId = "57L"; - String primitiveTransformId = "100L"; - long outputId = 101L; - - BeamFnApi.Target inputTarget = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference("1000L") - 
.setName("inputTarget") - .build(); + String inputId = "100L"; - Multimap>> existingConsumers = - ImmutableMultimap.of(); - Multimap>> newConsumers = - HashMultimap.create(); + Multimap>> consumers = HashMultimap.create(); List startFunctions = new ArrayList<>(); List finishFunctions = new ArrayList<>(); - BeamFnApi.FunctionSpec functionSpec = BeamFnApi.FunctionSpec.newBuilder() - .setId("1L") + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() .setUrn(DATA_OUTPUT_URN) - .setData(Any.pack(REMOTE_PORT)) + .setParameter(Any.pack(REMOTE_PORT)) .build(); - BeamFnApi.PrimitiveTransform primitiveTransform = BeamFnApi.PrimitiveTransform.newBuilder() - .setId(primitiveTransformId) - .setFunctionSpec(functionSpec) - .putInputs("input", BeamFnApi.Target.List.newBuilder().addTarget(inputTarget).build()) - .putOutputs(Long.toString(outputId), - BeamFnApi.PCollection.newBuilder().setCoderReference(STRING_CODER_SPEC_ID).build()) + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs(inputId, "inputPC") .build(); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -629,11 +566,13 @@ public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { fnApiRegistry::get, beamFnDataClient); - handler.createConsumersForPrimitiveTransform( - primitiveTransform, + handler.createRunnerForPTransform( + "ptransformId", + pTransform, Suppliers.ofInstance(bundleId)::get, - existingConsumers::get, - newConsumers::put, + ImmutableMap.of("inputPC", + RunnerApi.PCollection.newBuilder().setCoderId(STRING_CODER_SPEC_ID).build()), + consumers, startFunctions::add, finishFunctions::add); @@ -643,16 +582,16 @@ public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { AtomicBoolean wasCloseCalled = new AtomicBoolean(); CloseableThrowingConsumer> outputConsumer = new CloseableThrowingConsumer>(){ - @Override - public void close() throws Exception { - wasCloseCalled.set(true); - } + 
@Override + public void close() throws Exception { + wasCloseCalled.set(true); + } - @Override - public void accept(WindowedValue t) throws Exception { - outputValues.add(t); - } - }; + @Override + public void accept(WindowedValue t) throws Exception { + outputValues.add(t); + } + }; when(beamFnDataClient.forOutboundConsumer( any(), @@ -662,14 +601,13 @@ public void accept(WindowedValue t) throws Exception { verify(beamFnDataClient).forOutboundConsumer( eq(REMOTE_PORT.getApiServiceDescriptor()), eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(primitiveTransformId) - .setName(Long.toString(outputId)) + .setPrimitiveTransformReference("ptransformId") + .setName(inputId) .build())), eq(STRING_CODER)); - assertEquals(newConsumers.keySet(), ImmutableSet.of(inputTarget)); - Iterables.getOnlyElement(newConsumers.get(inputTarget)).accept( - valueInGlobalWindow("TestValue")); + assertThat(consumers.keySet(), containsInAnyOrder("inputPC")); + Iterables.getOnlyElement(consumers.get("inputPC")).accept(valueInGlobalWindow("TestValue")); assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); outputValues.clear(); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java index c32fcc4d6ac7..b1f441030fdb 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java @@ -27,6 +27,7 @@ import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService; import org.apache.beam.fn.v1.BeamFnApi; import org.apache.beam.fn.v1.BeamFnApi.RegisterResponse; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -41,12 +42,21 @@ public class RegisterHandlerTest { 
BeamFnApi.InstructionRequest.newBuilder() .setInstructionId("1L") .setRegister(BeamFnApi.RegisterRequest.newBuilder() - .addProcessBundleDescriptor(BeamFnApi.ProcessBundleDescriptor.newBuilder().setId("1L") - .addCoders(BeamFnApi.Coder.newBuilder().setFunctionSpec( - BeamFnApi.FunctionSpec.newBuilder().setId("10L")).build())) + .addProcessBundleDescriptor(BeamFnApi.ProcessBundleDescriptor.newBuilder() + .setId("1L") + .putCodersyyy("10L", RunnerApi.Coder.newBuilder() + .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("urn:10L").build()) + .build()) + .build()) + .build()) .addProcessBundleDescriptor(BeamFnApi.ProcessBundleDescriptor.newBuilder().setId("2L") - .addCoders(BeamFnApi.Coder.newBuilder().setFunctionSpec( - BeamFnApi.FunctionSpec.newBuilder().setId("20L")).build())) + .putCodersyyy("20L", RunnerApi.Coder.newBuilder() + .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("urn:20L").build()) + .build()) + .build()) + .build()) .build()) .build(); private static final BeamFnApi.InstructionResponse REGISTER_RESPONSE = @@ -71,9 +81,11 @@ public BeamFnApi.InstructionResponse call() throws Exception { handler.getById("1L")); assertEquals(REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(1), handler.getById("2L")); - assertEquals(REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(0).getCoders(0), + assertEquals( + REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(0).getCodersyyyOrThrow("10L"), handler.getById("10L")); - assertEquals(REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(1).getCoders(0), + assertEquals( + REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(1).getCodersyyyOrThrow("20L"), handler.getById("20L")); assertEquals(REGISTER_RESPONSE, responseFuture.get()); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java 
b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java index a3d4a1b0f109..7e8ab1a2216d 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java @@ -51,6 +51,7 @@ import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; @@ -71,16 +72,21 @@ public class BeamFnDataReadRunnerTest { private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder() .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build(); - private static final BeamFnApi.FunctionSpec FUNCTION_SPEC = BeamFnApi.FunctionSpec.newBuilder() - .setData(Any.pack(PORT_SPEC)).build(); + private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder() + .setParameter(Any.pack(PORT_SPEC)).build(); private static final Coder> CODER = WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); - private static final BeamFnApi.Coder CODER_SPEC; + private static final RunnerApi.Coder CODER_SPEC; static { try { - CODER_SPEC = BeamFnApi.Coder.newBuilder().setFunctionSpec(BeamFnApi.FunctionSpec.newBuilder() - .setData(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( - OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(CODER)))).build()))) + CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec( + RunnerApi.SdkFunctionSpec.newBuilder().setSpec( + RunnerApi.FunctionSpec.newBuilder().setParameter( + Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(CODER)))) 
+ .build())) + .build()) + .build()) .build(); } catch (IOException e) { throw new ExceptionInInitializerError(e); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java index 338396650b41..a3c874e54588 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java @@ -41,6 +41,7 @@ import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; @@ -59,17 +60,22 @@ public class BeamFnDataWriteRunnerTest { private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder() .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build(); - private static final BeamFnApi.FunctionSpec FUNCTION_SPEC = BeamFnApi.FunctionSpec.newBuilder() - .setData(Any.pack(PORT_SPEC)).build(); + private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder() + .setParameter(Any.pack(PORT_SPEC)).build(); private static final Coder> CODER = WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); - private static final BeamFnApi.Coder CODER_SPEC; + private static final RunnerApi.Coder CODER_SPEC; static { try { - CODER_SPEC = BeamFnApi.Coder.newBuilder().setFunctionSpec(BeamFnApi.FunctionSpec.newBuilder() - .setData(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( - OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(CODER)))).build()))) - .build(); + CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec( + 
RunnerApi.SdkFunctionSpec.newBuilder().setSpec( + RunnerApi.FunctionSpec.newBuilder().setParameter( + Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(CODER)))) + .build())) + .build()) + .build()) + .build(); } catch (IOException e) { throw new ExceptionInInitializerError(e); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java index 73860efc63b0..d8ed121a7041 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java @@ -33,7 +33,7 @@ import java.util.List; import java.util.Map; import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.CountingSource; import org.apache.beam.sdk.options.PipelineOptionsFactory; @@ -58,7 +58,7 @@ public void testRunReadLoopWithMultipleSources() throws Exception { BoundedSourceRunner, Long> runner = new BoundedSourceRunner<>( PipelineOptionsFactory.create(), - BeamFnApi.FunctionSpec.getDefaultInstance(), + RunnerApi.FunctionSpec.getDefaultInstance(), outputMap); runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(2))); @@ -81,7 +81,7 @@ public void testRunReadLoopWithEmptySource() throws Exception { BoundedSourceRunner, Long> runner = new BoundedSourceRunner<>( PipelineOptionsFactory.create(), - BeamFnApi.FunctionSpec.getDefaultInstance(), + RunnerApi.FunctionSpec.getDefaultInstance(), outputMap); runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(0))); @@ -101,7 +101,7 @@ public void testStart() throws Exception { BoundedSourceRunner, Long> runner = new BoundedSourceRunner<>( 
PipelineOptionsFactory.create(), - BeamFnApi.FunctionSpec.newBuilder().setData( + RunnerApi.FunctionSpec.newBuilder().setParameter( Any.pack(BytesValue.newBuilder().setValue(encodedSource).build())).build(), outputMap);