diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/InProcessEnvironmentFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/InProcessEnvironmentFactory.java
index 03a3b550eea6..b4c475248d60 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/InProcessEnvironmentFactory.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/InProcessEnvironmentFactory.java
@@ -88,6 +88,7 @@ public RemoteEnvironment createEnvironment(Environment container) throws Excepti
() -> {
try {
FnHarness.main(
+ "id",
options,
loggingServer.getApiServiceDescriptor(),
controlServer.getApiServiceDescriptor(),
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 2c95ebc585a0..25ba91e7f05b 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -131,6 +131,7 @@ public void setup() throws Exception {
sdkHarnessExecutor.submit(
() ->
FnHarness.main(
+ "id",
PipelineOptionsFactory.create(),
loggingServer.getApiServiceDescriptor(),
controlServer.getApiServiceDescriptor(),
diff --git a/runners/reference/java/src/main/java/org/apache/beam/runners/reference/testing/InProcessManagedChannelFactory.java b/runners/reference/java/src/main/java/org/apache/beam/runners/reference/testing/InProcessManagedChannelFactory.java
deleted file mode 100644
index e134aecc5beb..000000000000
--- a/runners/reference/java/src/main/java/org/apache/beam/runners/reference/testing/InProcessManagedChannelFactory.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.reference.testing;
-
-import io.grpc.ManagedChannel;
-import io.grpc.inprocess.InProcessChannelBuilder;
-import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor;
-import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
-
-/**
- * A {@link org.apache.beam.sdk.fn.channel.ManagedChannelFactory} that uses in-process channels.
- *
- *
The channel builder uses {@link ApiServiceDescriptor#getUrl()} as the unique in-process name.
- */
-public class InProcessManagedChannelFactory extends ManagedChannelFactory {
-
- @Override
- public ManagedChannel forDescriptor(ApiServiceDescriptor apiServiceDescriptor) {
- return InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build();
- }
-}
diff --git a/sdks/java/container/boot.go b/sdks/java/container/boot.go
index 1c80e0bab94e..ad7a35d2972d 100644
--- a/sdks/java/container/boot.go
+++ b/sdks/java/container/boot.go
@@ -92,6 +92,7 @@ func main() {
// (3) Invoke the Java harness, preserving artifact ordering in classpath.
+ os.Setenv("HARNESS_ID", *id)
os.Setenv("PIPELINE_OPTIONS", options)
os.Setenv("LOGGING_API_SERVICE_DESCRIPTOR", proto.MarshalTextString(&pb.ApiServiceDescriptor{Url: *loggingEndpoint}))
os.Setenv("CONTROL_API_SERVICE_DESCRIPTOR", proto.MarshalTextString(&pb.ApiServiceDescriptor{Url: *controlEndpoint}))
diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/channel/ManagedChannelFactory.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/channel/ManagedChannelFactory.java
index 0a4a35d58f1d..57d2c68e3635 100644
--- a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/channel/ManagedChannelFactory.java
+++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/channel/ManagedChannelFactory.java
@@ -18,6 +18,7 @@
package org.apache.beam.sdk.fn.channel;
+import io.grpc.ClientInterceptor;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.NettyChannelBuilder;
@@ -26,11 +27,10 @@
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.unix.DomainSocketAddress;
import java.net.SocketAddress;
+import java.util.List;
import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor;
-/**
- * A Factory which creates an underlying {@link ManagedChannel} implementation.
- */
+/** A Factory which creates an underlying {@link ManagedChannel} implementation. */
public abstract class ManagedChannelFactory {
public static ManagedChannelFactory createDefault() {
return new Default();
@@ -41,7 +41,20 @@ public static ManagedChannelFactory createEpoll() {
return new Epoll();
}
- public abstract ManagedChannel forDescriptor(ApiServiceDescriptor apiServiceDescriptor);
+ public ManagedChannel forDescriptor(ApiServiceDescriptor apiServiceDescriptor) {
+ return builderFor(apiServiceDescriptor).build();
+ }
+
+ /** Create a {@link ManagedChannelBuilder} for the provided {@link ApiServiceDescriptor}. */
+ protected abstract ManagedChannelBuilder> builderFor(ApiServiceDescriptor descriptor);
+
+ /**
+ * Returns a {@link ManagedChannelFactory} like this one, but which will apply the provided {@link
+ * ClientInterceptor ClientInterceptors} to any channel it creates.
+ */
+ public ManagedChannelFactory withInterceptors(List interceptors) {
+ return new InterceptedManagedChannelFactory(this, interceptors);
+ }
/**
* Creates a {@link ManagedChannel} backed by an {@link EpollDomainSocketChannel} if the address
@@ -50,17 +63,18 @@ public static ManagedChannelFactory createEpoll() {
*/
private static class Epoll extends ManagedChannelFactory {
@Override
- public ManagedChannel forDescriptor(ApiServiceDescriptor apiServiceDescriptor) {
+ public ManagedChannelBuilder> builderFor(ApiServiceDescriptor apiServiceDescriptor) {
SocketAddress address = SocketAddressFactory.createFrom(apiServiceDescriptor.getUrl());
return NettyChannelBuilder.forAddress(address)
- .channelType(address instanceof DomainSocketAddress
- ? EpollDomainSocketChannel.class : EpollSocketChannel.class)
+ .channelType(
+ address instanceof DomainSocketAddress
+ ? EpollDomainSocketChannel.class
+ : EpollSocketChannel.class)
.eventLoopGroup(new EpollEventLoopGroup())
.usePlaintext(true)
// Set the message size to max value here. The actual size is governed by the
// buffer size in the layers above.
- .maxInboundMessageSize(Integer.MAX_VALUE)
- .build();
+ .maxInboundMessageSize(Integer.MAX_VALUE);
}
}
@@ -70,13 +84,38 @@ public ManagedChannel forDescriptor(ApiServiceDescriptor apiServiceDescriptor) {
*/
private static class Default extends ManagedChannelFactory {
@Override
- public ManagedChannel forDescriptor(ApiServiceDescriptor apiServiceDescriptor) {
+ public ManagedChannelBuilder> builderFor(ApiServiceDescriptor apiServiceDescriptor) {
return ManagedChannelBuilder.forTarget(apiServiceDescriptor.getUrl())
.usePlaintext(true)
// Set the message size to max value here. The actual size is governed by the
// buffer size in the layers above.
- .maxInboundMessageSize(Integer.MAX_VALUE)
- .build();
+ .maxInboundMessageSize(Integer.MAX_VALUE);
+ }
+ }
+
+ private static class InterceptedManagedChannelFactory extends ManagedChannelFactory {
+ private final ManagedChannelFactory channelFactory;
+ private final List interceptors;
+
+ private InterceptedManagedChannelFactory(
+ ManagedChannelFactory managedChannelFactory, List interceptors) {
+ this.channelFactory = managedChannelFactory;
+ this.interceptors = interceptors;
+ }
+
+ @Override
+ public ManagedChannel forDescriptor(ApiServiceDescriptor apiServiceDescriptor) {
+ return builderFor(apiServiceDescriptor).intercept(interceptors).build();
+ }
+
+ @Override
+ protected ManagedChannelBuilder> builderFor(ApiServiceDescriptor descriptor) {
+ return channelFactory.builderFor(descriptor);
+ }
+
+ @Override
+ public ManagedChannelFactory withInterceptors(List interceptors) {
+ return new InterceptedManagedChannelFactory(channelFactory, interceptors);
}
}
}
diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/test/InProcessManagedChannelFactory.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/test/InProcessManagedChannelFactory.java
index 787047b7f6ef..6e0d87fe9073 100644
--- a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/test/InProcessManagedChannelFactory.java
+++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/test/InProcessManagedChannelFactory.java
@@ -17,7 +17,7 @@
*/
package org.apache.beam.sdk.fn.test;
-import io.grpc.ManagedChannel;
+import io.grpc.ManagedChannelBuilder;
import io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor;
import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
@@ -35,7 +35,7 @@ public static ManagedChannelFactory create() {
private InProcessManagedChannelFactory() {}
@Override
- public ManagedChannel forDescriptor(ApiServiceDescriptor apiServiceDescriptor) {
- return InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build();
+ public ManagedChannelBuilder> builderFor(ApiServiceDescriptor apiServiceDescriptor) {
+ return InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl());
}
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
index 204e828092d6..8d036b27adc8 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
@@ -49,18 +49,21 @@
* Main entry point into the Beam SDK Fn Harness for Java.
*
*
This entry point expects the following environment variables:
+ *
*
- *
LOGGING_API_SERVICE_DESCRIPTOR: A
- * {@link org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor} encoded as text
- * representing the endpoint that is to be connected to for the Beam Fn Logging service.
- *
CONTROL_API_SERVICE_DESCRIPTOR: A
- * {@link Endpoints.ApiServiceDescriptor} encoded as text
- * representing the endpoint that is to be connected to for the Beam Fn Control service.
+ *
HARNESS_ID: A String representing the ID of this FnHarness. This will be added to the
+ * headers of calls to the Beam Control Service
+ *
LOGGING_API_SERVICE_DESCRIPTOR: A {@link
+ * org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor} encoded as text
+ * representing the endpoint that is to be connected to for the Beam Fn Logging service.
+ *
CONTROL_API_SERVICE_DESCRIPTOR: A {@link Endpoints.ApiServiceDescriptor} encoded as text
+ * representing the endpoint that is to be connected to for the Beam Fn Control service.
*
PIPELINE_OPTIONS: A serialized form of {@link PipelineOptions}. See {@link PipelineOptions}
- * for further details.
+ * for further details.
*
*/
public class FnHarness {
+ private static final String HARNESS_ID = "HARNESS_ID";
private static final String CONTROL_API_SERVICE_DESCRIPTOR = "CONTROL_API_SERVICE_DESCRIPTOR";
private static final String LOGGING_API_SERVICE_DESCRIPTOR = "LOGGING_API_SERVICE_DESCRIPTOR";
private static final String PIPELINE_OPTIONS = "PIPELINE_OPTIONS";
@@ -76,14 +79,17 @@ private static Endpoints.ApiServiceDescriptor getApiServiceDescriptor(String env
public static void main(String[] args) throws Exception {
System.out.format("SDK Fn Harness started%n");
+ System.out.format("Harness ID %s%n", System.getenv(HARNESS_ID));
System.out.format("Logging location %s%n", System.getenv(LOGGING_API_SERVICE_DESCRIPTOR));
System.out.format("Control location %s%n", System.getenv(CONTROL_API_SERVICE_DESCRIPTOR));
System.out.format("Pipeline options %s%n", System.getenv(PIPELINE_OPTIONS));
- ObjectMapper objectMapper = new ObjectMapper().registerModules(
- ObjectMapper.findModules(ReflectHelpers.findClassLoader()));
- PipelineOptions options = objectMapper.readValue(
- System.getenv(PIPELINE_OPTIONS), PipelineOptions.class);
+ String id = System.getenv(HARNESS_ID);
+ ObjectMapper objectMapper =
+ new ObjectMapper()
+ .registerModules(ObjectMapper.findModules(ReflectHelpers.findClassLoader()));
+ PipelineOptions options =
+ objectMapper.readValue(System.getenv(PIPELINE_OPTIONS), PipelineOptions.class);
Endpoints.ApiServiceDescriptor loggingApiServiceDescriptor =
getApiServiceDescriptor(LOGGING_API_SERVICE_DESCRIPTOR);
@@ -91,12 +97,15 @@ public static void main(String[] args) throws Exception {
Endpoints.ApiServiceDescriptor controlApiServiceDescriptor =
getApiServiceDescriptor(CONTROL_API_SERVICE_DESCRIPTOR);
- main(options, loggingApiServiceDescriptor, controlApiServiceDescriptor);
+ main(id, options, loggingApiServiceDescriptor, controlApiServiceDescriptor);
}
- public static void main(PipelineOptions options,
+ public static void main(
+ String id,
+ PipelineOptions options,
Endpoints.ApiServiceDescriptor loggingApiServiceDescriptor,
- Endpoints.ApiServiceDescriptor controlApiServiceDescriptor) throws Exception {
+ Endpoints.ApiServiceDescriptor controlApiServiceDescriptor)
+ throws Exception {
ManagedChannelFactory channelFactory;
List experiments = options.as(ExperimentalOptions.class).getExperiments();
if (experiments != null && experiments.contains("beam_fn_api_epoll")) {
@@ -107,6 +116,7 @@ public static void main(PipelineOptions options,
StreamObserverFactory streamObserverFactory =
HarnessStreamObserverFactories.fromOptions(options);
main(
+ id,
options,
loggingApiServiceDescriptor,
controlApiServiceDescriptor,
@@ -115,43 +125,46 @@ public static void main(PipelineOptions options,
}
public static void main(
+ String id,
PipelineOptions options,
Endpoints.ApiServiceDescriptor loggingApiServiceDescriptor,
Endpoints.ApiServiceDescriptor controlApiServiceDescriptor,
ManagedChannelFactory channelFactory,
StreamObserverFactory streamObserverFactory) {
IdGenerator idGenerator = IdGenerators.decrementingLongs();
- try (BeamFnLoggingClient logging = new BeamFnLoggingClient(
- options,
- loggingApiServiceDescriptor,
- channelFactory::forDescriptor)) {
+ try (BeamFnLoggingClient logging =
+ new BeamFnLoggingClient(
+ options, loggingApiServiceDescriptor, channelFactory::forDescriptor)) {
LOG.info("Fn Harness started");
- EnumMap>
handlers = new EnumMap<>(BeamFnApi.InstructionRequest.RequestCase.class);
RegisterHandler fnApiRegistry = new RegisterHandler();
- BeamFnDataGrpcClient beamFnDataMultiplexer = new BeamFnDataGrpcClient(
- options, channelFactory::forDescriptor, streamObserverFactory::from);
+ BeamFnDataGrpcClient beamFnDataMultiplexer =
+ new BeamFnDataGrpcClient(
+ options, channelFactory::forDescriptor, streamObserverFactory::from);
BeamFnStateGrpcClientCache beamFnStateGrpcClientCache =
new BeamFnStateGrpcClientCache(
options, idGenerator, channelFactory::forDescriptor, streamObserverFactory::from);
- ProcessBundleHandler processBundleHandler = new ProcessBundleHandler(
- options,
- fnApiRegistry::getById,
- beamFnDataMultiplexer,
- beamFnStateGrpcClientCache);
- handlers.put(BeamFnApi.InstructionRequest.RequestCase.REGISTER,
- fnApiRegistry::register);
- handlers.put(BeamFnApi.InstructionRequest.RequestCase.PROCESS_BUNDLE,
+ ProcessBundleHandler processBundleHandler =
+ new ProcessBundleHandler(
+ options, fnApiRegistry::getById, beamFnDataMultiplexer, beamFnStateGrpcClientCache);
+ handlers.put(BeamFnApi.InstructionRequest.RequestCase.REGISTER, fnApiRegistry::register);
+ handlers.put(
+ BeamFnApi.InstructionRequest.RequestCase.PROCESS_BUNDLE,
processBundleHandler::processBundle);
- BeamFnControlClient control = new BeamFnControlClient(controlApiServiceDescriptor,
- channelFactory::forDescriptor,
- streamObserverFactory::from,
- handlers);
+ BeamFnControlClient control =
+ new BeamFnControlClient(
+ id,
+ controlApiServiceDescriptor,
+ channelFactory,
+ streamObserverFactory::from,
+ handlers);
LOG.info("Entering instruction processing loop");
control.processInstructionRequests(options.as(GcsOptions.class).getExecutorService());
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/AddHarnessIdInterceptor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/AddHarnessIdInterceptor.java
new file mode 100644
index 000000000000..8a2156071715
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/AddHarnessIdInterceptor.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package org.apache.beam.fn.harness.control;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import io.grpc.ClientInterceptor;
+import io.grpc.Metadata;
+import io.grpc.Metadata.Key;
+import io.grpc.stub.MetadataUtils;
+
+/** A {@link ClientInterceptor} that attaches a provided SDK Harness ID to outgoing messages. */
+public class AddHarnessIdInterceptor {
+ private static final Key ID_KEY = Key.of("worker_id", Metadata.ASCII_STRING_MARSHALLER);
+
+ public static ClientInterceptor create(String harnessId) {
+ checkArgument(harnessId != null, "harnessId must not be null");
+ Metadata md = new Metadata();
+ md.put(ID_KEY, harnessId);
+ return MetadataUtils.newAttachHeadersInterceptor(md);
+ }
+
+ // This is implemented via MetadataUtils, so we never actually create an instance of this class
+ private AddHarnessIdInterceptor() {}
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BeamFnControlClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BeamFnControlClient.java
index ab932c797d4a..7c0ed198a4ee 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BeamFnControlClient.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BeamFnControlClient.java
@@ -20,8 +20,8 @@
import static com.google.common.base.Throwables.getStackTraceAsString;
+import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.Uninterruptibles;
-import io.grpc.ManagedChannel;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import java.util.EnumMap;
@@ -31,11 +31,14 @@
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.function.BiFunction;
-import java.util.function.Function;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionRequest.RequestCase;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse.Builder;
import org.apache.beam.model.fnexecution.v1.BeamFnControlGrpc;
-import org.apache.beam.model.pipeline.v1.Endpoints;
+import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor;
+import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
import org.apache.beam.sdk.fn.function.ThrowingFunction;
import org.apache.beam.sdk.fn.stream.StreamObserverFactory.StreamObserverClientFactory;
import org.slf4j.Logger;
@@ -69,21 +72,23 @@ public class BeamFnControlClient {
private final CompletableFuture