From 4299e388b7b391275003d4e7fad2ff98d559c7e6 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Fri, 20 Sep 2019 16:09:19 -0700 Subject: [PATCH] Sets workerHarnessContainerImage as the containerImage of the DockerPayload of the default environment for DataflowRunner --- .../dataflow/DataflowPipelineTranslator.java | 29 ++++++++++++++---- .../DataflowPipelineTranslatorTest.java | 30 +++++++++++++++++++ 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index c0c5e553036e..41e5cbbff50e 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -74,6 +74,7 @@ import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.io.Read; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.TransformHierarchy; @@ -121,10 +122,17 @@ public class DataflowPipelineTranslator { private static final Logger LOG = LoggerFactory.getLogger(DataflowPipelineTranslator.class); private static final ObjectMapper MAPPER = new ObjectMapper(); - private static byte[] serializeWindowingStrategy(WindowingStrategy windowingStrategy) { + private static byte[] serializeWindowingStrategy( + WindowingStrategy windowingStrategy, PipelineOptions options) { try { SdkComponents sdkComponents = SdkComponents.create(); - sdkComponents.registerEnvironment(Environments.JAVA_SDK_HARNESS_ENVIRONMENT); + + String workerHarnessContainerImageURL = 
DataflowRunner.getContainerImageForJob(options.as(DataflowPipelineOptions.class)); + RunnerApi.Environment defaultEnvironmentForDataflow = + Environments.createDockerEnvironment(workerHarnessContainerImageURL); + sdkComponents.registerEnvironment(defaultEnvironmentForDataflow); + return WindowingStrategyTranslation.toMessageProto(windowingStrategy, sdkComponents) .toByteArray(); } catch (Exception e) { @@ -164,7 +172,13 @@ public JobSpecification translate( // Capture the sdkComponents for look up during step translations SdkComponents sdkComponents = SdkComponents.create(); - sdkComponents.registerEnvironment(Environments.JAVA_SDK_HARNESS_ENVIRONMENT); + + String workerHarnessContainerImageURL = + DataflowRunner.getContainerImageForJob(options.as(DataflowPipelineOptions.class)); + RunnerApi.Environment defaultEnvironmentForDataflow = + Environments.createDockerEnvironment(workerHarnessContainerImageURL); + sdkComponents.registerEnvironment(defaultEnvironmentForDataflow); + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline, sdkComponents, true); LOG.debug("Portable pipeline proto:\n{}", TextFormat.printToString(pipelineProto)); @@ -754,7 +768,8 @@ private void translateTyped( WindowingStrategy windowingStrategy = input.getWindowingStrategy(); stepContext.addInput( PropertyNames.WINDOWING_STRATEGY, - byteArrayToJsonString(serializeWindowingStrategy(windowingStrategy))); + byteArrayToJsonString( + serializeWindowingStrategy(windowingStrategy, context.getPipelineOptions()))); stepContext.addInput( PropertyNames.IS_MERGING_WINDOW_FN, !windowingStrategy.getWindowFn().isNonMerging()); @@ -898,7 +913,8 @@ private void groupByKeyHelper( stepContext.addInput(PropertyNames.DISALLOW_COMBINER_LIFTING, !allowCombinerLifting); stepContext.addInput( PropertyNames.SERIALIZED_FN, - byteArrayToJsonString(serializeWindowingStrategy(windowingStrategy))); + byteArrayToJsonString( + serializeWindowingStrategy(windowingStrategy, context.getPipelineOptions()))); 
stepContext.addInput( PropertyNames.IS_MERGING_WINDOW_FN, !windowingStrategy.getWindowFn().isNonMerging()); @@ -1039,7 +1055,8 @@ private void translateHelper(Window.Assign transform, TranslationContext stepContext.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); WindowingStrategy strategy = context.getOutput(transform).getWindowingStrategy(); - byte[] serializedBytes = serializeWindowingStrategy(strategy); + byte[] serializedBytes = + serializeWindowingStrategy(strategy, context.getPipelineOptions()); String serializedJson = byteArrayToJsonString(serializedBytes); stepContext.addInput(PropertyNames.SERIALIZED_FN, serializedJson); } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java index f8cc7b6e23ab..85b1e22cbe77 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java @@ -50,6 +50,8 @@ import java.util.Set; import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.model.pipeline.v1.RunnerApi.Components; +import org.apache.beam.model.pipeline.v1.RunnerApi.DockerPayload; +import org.apache.beam.model.pipeline.v1.RunnerApi.Environment; import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload; import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.ParDoTranslation; @@ -956,6 +958,34 @@ public void populateDisplayData(DisplayData.Builder builder) { assertEquals(expectedFn2DisplayData, ImmutableSet.copyOf(fn2displayData)); } + /** + * Tests that when {@link DataflowPipelineOptions#setWorkerHarnessContainerImage(String)} pipeline + * option is set, {@link DataflowRunner} sets 
that value as the {@link + * DockerPayload#getContainerImage()} of the default {@link Environment} used when generating the + * model pipeline proto. + */ + @Test + public void testSetWorkerHarnessContainerImageInPipelineProto() throws Exception { + DataflowPipelineOptions options = buildPipelineOptions(); + String containerImage = "gcr.io/IMAGE/foo"; + options.as(DataflowPipelineOptions.class).setWorkerHarnessContainerImage(containerImage); + + JobSpecification specification = + DataflowPipelineTranslator.fromOptions(options) + .translate( + Pipeline.create(options), + DataflowRunner.fromOptions(options), + Collections.emptyList()); + RunnerApi.Pipeline pipelineProto = specification.getPipelineProto(); + + assertEquals(1, pipelineProto.getComponents().getEnvironmentsCount()); + Environment defaultEnvironment = + Iterables.getOnlyElement(pipelineProto.getComponents().getEnvironmentsMap().values()); + + DockerPayload payload = DockerPayload.parseFrom(defaultEnvironment.getPayload()); + assertEquals(DataflowRunner.getContainerImageForJob(options), payload.getContainerImage()); + } + private static void assertAllStepOutputsHaveUniqueIds(Job job) throws Exception { List outputIds = new ArrayList<>(); for (Step step : job.getSteps()) {