Skip to content

Commit

Permalink
[BEAM-13015] Use a DirectExecutor for state since we are just complet…
Browse files Browse the repository at this point in the history
…ing futures on the callback. (#16745)
  • Loading branch information
lukecwik committed Feb 26, 2022
1 parent 9ef3247 commit 939af65
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 157 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@
import org.apache.beam.runners.fnexecution.logging.GrpcLoggingService;
import org.apache.beam.runners.fnexecution.provisioning.StaticGrpcProvisionService;
import org.apache.beam.sdk.fn.IdGenerator;
import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
import org.apache.beam.sdk.fn.server.GrpcFnServer;
import org.apache.beam.sdk.fn.server.InProcessServerFactory;
import org.apache.beam.sdk.fn.server.ServerFactory;
import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
import org.apache.beam.sdk.fn.test.InProcessManagedChannelFactory;
import org.apache.beam.sdk.options.PipelineOptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -106,7 +106,7 @@ public RemoteEnvironment createEnvironment(Environment environment, String worke
loggingServer.getApiServiceDescriptor(),
controlServer.getApiServiceDescriptor(),
null,
InProcessManagedChannelFactory.create(),
ManagedChannelFactory.createInProcess(),
OutboundObserverFactory.clientDirect(),
Caches.fromOptions(options));
} catch (NoClassDefFoundError e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.server.GrpcContextHeaderAccessorProvider;
import org.apache.beam.sdk.fn.server.GrpcFnServer;
import org.apache.beam.sdk.fn.server.InProcessServerFactory;
import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
import org.apache.beam.sdk.fn.test.InProcessManagedChannelFactory;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.options.ExperimentalOptions;
import org.apache.beam.sdk.options.PipelineOptions;
Expand Down Expand Up @@ -219,7 +219,7 @@ public void launchSdkHarness(PipelineOptions options) throws Exception {
loggingServer.getApiServiceDescriptor(),
controlServer.getApiServiceDescriptor(),
null,
InProcessManagedChannelFactory.create(),
ManagedChannelFactory.createInProcess(),
OutboundObserverFactory.clientDirect(),
Caches.eternal());
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import org.apache.beam.runners.portability.testing.TestJobService;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.PipelineResult.State;
import org.apache.beam.sdk.fn.test.InProcessManagedChannelFactory;
import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
import org.apache.beam.sdk.metrics.MetricQueryResults;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
Expand Down Expand Up @@ -94,7 +94,7 @@ public class PortableRunnerTest implements Serializable {
@Test
public void stagesAndRunsJob() throws Exception {
createJobServer(JobState.Enum.DONE, JobApi.MetricResults.getDefaultInstance());
PortableRunner runner = PortableRunner.create(options, InProcessManagedChannelFactory.create());
PortableRunner runner = PortableRunner.create(options, ManagedChannelFactory.createInProcess());
State state = runner.run(p).waitUntilFinish();
assertThat(state, is(State.DONE));
}
Expand All @@ -103,7 +103,7 @@ public void stagesAndRunsJob() throws Exception {
public void extractsMetrics() throws Exception {
JobApi.MetricResults metricResults = generateMetricResults();
createJobServer(JobState.Enum.DONE, metricResults);
PortableRunner runner = PortableRunner.create(options, InProcessManagedChannelFactory.create());
PortableRunner runner = PortableRunner.create(options, ManagedChannelFactory.createInProcess());
PipelineResult result = runner.run(p);
result.waitUntilFinish();
MetricQueryResults metricQueryResults = result.metrics().allMetrics();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,104 +17,122 @@
*/
package org.apache.beam.sdk.fn.channel;

import avro.shaded.com.google.common.collect.ImmutableList;
import java.net.SocketAddress;
import java.util.Collections;
import java.util.List;
import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor;
import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.ClientInterceptor;
import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.ManagedChannelBuilder;
import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.netty.NettyChannelBuilder;
import org.apache.beam.vendor.grpc.v1p43p2.io.netty.channel.epoll.EpollDomainSocketChannel;
import org.apache.beam.vendor.grpc.v1p43p2.io.netty.channel.epoll.EpollEventLoopGroup;
import org.apache.beam.vendor.grpc.v1p43p2.io.netty.channel.epoll.EpollSocketChannel;
import org.apache.beam.vendor.grpc.v1p43p2.io.netty.channel.unix.DomainSocketAddress;

/** A Factory which creates an underlying {@link ManagedChannel} implementation. */
public abstract class ManagedChannelFactory {
/** A Factory which creates {@link ManagedChannel} instances. */
public class ManagedChannelFactory {
/**
* Creates a {@link ManagedChannel} relying on the {@link ManagedChannelBuilder} to choose the
* channel type.
*/
public static ManagedChannelFactory createDefault() {
return new Default();
return new ManagedChannelFactory(Type.DEFAULT, Collections.emptyList(), false);
}

/**
* Creates a {@link ManagedChannelFactory} backed by an {@link EpollDomainSocketChannel} if the
* address is a {@link DomainSocketAddress}. Otherwise creates a {@link ManagedChannel} backed by
* an {@link EpollSocketChannel}.
*/
public static ManagedChannelFactory createEpoll() {
org.apache.beam.vendor.grpc.v1p43p2.io.netty.channel.epoll.Epoll.ensureAvailability();
return new Epoll();
return new ManagedChannelFactory(Type.EPOLL, Collections.emptyList(), false);
}

public ManagedChannel forDescriptor(ApiServiceDescriptor apiServiceDescriptor) {
return builderFor(apiServiceDescriptor).build();
/** Creates a {@link ManagedChannel} using an in-process channel. */
public static ManagedChannelFactory createInProcess() {
return new ManagedChannelFactory(Type.IN_PROCESS, Collections.emptyList(), false);
}

/** Create a {@link ManagedChannelBuilder} for the provided {@link ApiServiceDescriptor}. */
protected abstract ManagedChannelBuilder<?> builderFor(ApiServiceDescriptor descriptor);
public ManagedChannel forDescriptor(ApiServiceDescriptor apiServiceDescriptor) {
ManagedChannelBuilder<?> channelBuilder;
switch (type) {
case EPOLL:
SocketAddress address = SocketAddressFactory.createFrom(apiServiceDescriptor.getUrl());
channelBuilder =
NettyChannelBuilder.forAddress(address)
.channelType(
address instanceof DomainSocketAddress
? EpollDomainSocketChannel.class
: EpollSocketChannel.class)
.eventLoopGroup(new EpollEventLoopGroup());
break;

/**
* Returns a {@link ManagedChannelFactory} like this one, but which will apply the provided {@link
* ClientInterceptor ClientInterceptors} to any channel it creates.
*/
public ManagedChannelFactory withInterceptors(List<ClientInterceptor> interceptors) {
return new InterceptedManagedChannelFactory(this, interceptors);
}
case DEFAULT:
channelBuilder = ManagedChannelBuilder.forTarget(apiServiceDescriptor.getUrl());
break;

/**
* Creates a {@link ManagedChannel} backed by an {@link EpollDomainSocketChannel} if the address
* is a {@link DomainSocketAddress}. Otherwise creates a {@link ManagedChannel} backed by an
* {@link EpollSocketChannel}.
*/
private static class Epoll extends ManagedChannelFactory {
@Override
public ManagedChannelBuilder<?> builderFor(ApiServiceDescriptor apiServiceDescriptor) {
SocketAddress address = SocketAddressFactory.createFrom(apiServiceDescriptor.getUrl());
return NettyChannelBuilder.forAddress(address)
.channelType(
address instanceof DomainSocketAddress
? EpollDomainSocketChannel.class
: EpollSocketChannel.class)
.eventLoopGroup(new EpollEventLoopGroup())
.usePlaintext()
// Set the message size to max value here. The actual size is governed by the
// buffer size in the layers above.
.maxInboundMessageSize(Integer.MAX_VALUE);
case IN_PROCESS:
channelBuilder = InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl());
break;

default:
throw new IllegalStateException("Unknown type " + type);
}
}

/**
* Creates a {@link ManagedChannel} relying on the {@link ManagedChannelBuilder} to create
* instances.
*/
private static class Default extends ManagedChannelFactory {
@Override
public ManagedChannelBuilder<?> builderFor(ApiServiceDescriptor apiServiceDescriptor) {
return ManagedChannelBuilder.forTarget(apiServiceDescriptor.getUrl())
.usePlaintext()
// Set the message size to max value here. The actual size is governed by the
// buffer size in the layers above.
.maxInboundMessageSize(Integer.MAX_VALUE);
channelBuilder =
channelBuilder
.usePlaintext()
// Set the message size to max value here. The actual size is governed by the
// buffer size in the layers above.
.maxInboundMessageSize(Integer.MAX_VALUE)
.intercept(interceptors);
if (directExecutor) {
channelBuilder = channelBuilder.directExecutor();
}
return channelBuilder.build();
}

private static class InterceptedManagedChannelFactory extends ManagedChannelFactory {
private final ManagedChannelFactory channelFactory;
private final List<ClientInterceptor> interceptors;
/** The channel type. */
private enum Type {
EPOLL,
DEFAULT,
IN_PROCESS,
}

private InterceptedManagedChannelFactory(
ManagedChannelFactory managedChannelFactory, List<ClientInterceptor> interceptors) {
this.channelFactory = managedChannelFactory;
this.interceptors = interceptors;
}
private final Type type;
private final List<ClientInterceptor> interceptors;
private final boolean directExecutor;

@Override
public ManagedChannel forDescriptor(ApiServiceDescriptor apiServiceDescriptor) {
return builderFor(apiServiceDescriptor).intercept(interceptors).build();
}
private ManagedChannelFactory(
Type type, List<ClientInterceptor> interceptors, boolean directExecutor) {
this.type = type;
this.interceptors = interceptors;
this.directExecutor = directExecutor;
}

@Override
protected ManagedChannelBuilder<?> builderFor(ApiServiceDescriptor descriptor) {
return channelFactory.builderFor(descriptor);
}
/**
* Returns a {@link ManagedChannelFactory} like this one, but which will apply the provided {@link
* ClientInterceptor ClientInterceptors} to any channel it creates.
*/
public ManagedChannelFactory withInterceptors(List<ClientInterceptor> interceptors) {
return new ManagedChannelFactory(
type,
ImmutableList.<ClientInterceptor>builder()
.addAll(this.interceptors)
.addAll(interceptors)
.build(),
directExecutor);
}

@Override
public ManagedChannelFactory withInterceptors(List<ClientInterceptor> interceptors) {
return new InterceptedManagedChannelFactory(channelFactory, interceptors);
}
/**
* Returns a {@link ManagedChannelFactory} like this one, but will construct the channel to use
* the direct executor.
*/
public ManagedChannelFactory withDirectExecutor() {
return new ManagedChannelFactory(type, interceptors, true);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
import org.apache.beam.runners.core.metrics.SimpleExecutionState;
import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
import org.apache.beam.sdk.fn.test.InProcessManagedChannelFactory;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.inprocess.InProcessServerBuilder;
Expand Down Expand Up @@ -89,7 +88,7 @@ public ManageLoggingClientAndService() {
ApiServiceDescriptor.newBuilder()
.setUrl(BeamFnLoggingClientBenchmark.class.getName() + "#" + UUID.randomUUID())
.build();
ManagedChannelFactory managedChannelFactory = InProcessManagedChannelFactory.create();
ManagedChannelFactory managedChannelFactory = ManagedChannelFactory.createInProcess();
loggingService = new CallCountLoggingService();
server =
InProcessServerBuilder.forName(apiServiceDescriptor.getUrl())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -994,32 +994,30 @@ public void onClaimFailed(PositionT position) {}
/** Internal class to hold the primary and residual roots when converted to an input element. */
@AutoValue
@AutoValue.CopyAnnotations
@SuppressWarnings({"rawtypes"})
abstract static class WindowedSplitResult {
public static WindowedSplitResult forRoots(
WindowedValue primaryInFullyProcessedWindowsRoot,
WindowedValue primarySplitRoot,
WindowedValue residualSplitRoot,
WindowedValue residualInUnprocessedWindowsRoot) {
WindowedValue<?> primaryInFullyProcessedWindowsRoot,
WindowedValue<?> primarySplitRoot,
WindowedValue<?> residualSplitRoot,
WindowedValue<?> residualInUnprocessedWindowsRoot) {
return new AutoValue_FnApiDoFnRunner_WindowedSplitResult(
primaryInFullyProcessedWindowsRoot,
primarySplitRoot,
residualSplitRoot,
residualInUnprocessedWindowsRoot);
}

public abstract @Nullable WindowedValue getPrimaryInFullyProcessedWindowsRoot();
public abstract @Nullable WindowedValue<?> getPrimaryInFullyProcessedWindowsRoot();

public abstract @Nullable WindowedValue getPrimarySplitRoot();
public abstract @Nullable WindowedValue<?> getPrimarySplitRoot();

public abstract @Nullable WindowedValue getResidualSplitRoot();
public abstract @Nullable WindowedValue<?> getResidualSplitRoot();

public abstract @Nullable WindowedValue getResidualInUnprocessedWindowsRoot();
public abstract @Nullable WindowedValue<?> getResidualInUnprocessedWindowsRoot();
}

@AutoValue
@AutoValue.CopyAnnotations
@SuppressWarnings({"rawtypes"})
abstract static class SplitResultsWithStopIndex {
public static SplitResultsWithStopIndex of(
WindowedSplitResult windowSplit,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,7 @@ public static void main(
new BeamFnDataGrpcClient(options, channelFactory::forDescriptor, outboundObserverFactory);

BeamFnStateGrpcClientCache beamFnStateGrpcClientCache =
new BeamFnStateGrpcClientCache(
idGenerator, channelFactory::forDescriptor, outboundObserverFactory);
new BeamFnStateGrpcClientCache(idGenerator, channelFactory, outboundObserverFactory);

FinalizeBundleHandler finalizeBundleHandler =
new FinalizeBundleHandler(options.as(GcsOptions.class).getExecutorService());
Expand Down
Loading

0 comments on commit 939af65

Please sign in to comment.