Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BEAM-8670] Manage environment parallelism in DefaultJobBundleFactory #10124

Merged
merged 1 commit into from
Nov 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.apache.beam.runners.core.construction.PipelineOptionsTranslation;
import org.apache.beam.runners.fnexecution.control.DefaultExecutableStageContext.MultiInstanceFactory;
import org.apache.beam.runners.fnexecution.control.DefaultExecutableStageContext;
import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
import org.apache.beam.runners.fnexecution.control.ReferenceCountingExecutableStageContextFactory;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.sdk.options.PortablePipelineOptions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
import org.apache.flink.api.java.ExecutionEnvironment;

/** Singleton class that contains one {@link MultiInstanceFactory} per job. */
/** Singleton class that contains one {@link ExecutableStageContext.Factory} per job. */
public class FlinkExecutableStageContextFactory implements ExecutableStageContext.Factory {

private static final FlinkExecutableStageContextFactory instance =
Expand All @@ -36,7 +34,7 @@ public class FlinkExecutableStageContextFactory implements ExecutableStageContex
// classloader and therefore its own instance of FlinkExecutableStageContextFactory. This
// code supports multiple JobInfos in order to provide a sensible implementation of
// Factory.get(JobInfo), which in theory could be called with different JobInfos.
private static final ConcurrentMap<String, MultiInstanceFactory> jobFactories =
private static final ConcurrentMap<String, ExecutableStageContext.Factory> jobFactories =
new ConcurrentHashMap<>();

private FlinkExecutableStageContextFactory() {}
Expand All @@ -47,17 +45,12 @@ public static FlinkExecutableStageContextFactory getInstance() {

@Override
public ExecutableStageContext get(JobInfo jobInfo) {
MultiInstanceFactory jobFactory =
ExecutableStageContext.Factory jobFactory =
jobFactories.computeIfAbsent(
jobInfo.jobId(),
k -> {
PortablePipelineOptions portableOptions =
PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions())
.as(PortablePipelineOptions.class);

return new MultiInstanceFactory(
MoreObjects.firstNonNull(portableOptions.getSdkWorkerParallelism(), 1L)
.intValue(),
return ReferenceCountingExecutableStageContextFactory.create(
DefaultExecutableStageContext::create,
// Clean up context immediately if its class is not loaded on Flink parent
// classloader.
(caller) ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,14 @@
*/
package org.apache.beam.runners.fnexecution.control;

import java.util.ArrayList;
import java.util.List;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;

/** Implementation of a {@link ExecutableStageContext}. */
public class DefaultExecutableStageContext implements ExecutableStageContext, AutoCloseable {
private final JobBundleFactory jobBundleFactory;

private static DefaultExecutableStageContext create(JobInfo jobInfo) {
public static DefaultExecutableStageContext create(JobInfo jobInfo) {
JobBundleFactory jobBundleFactory = DefaultJobBundleFactory.create(jobInfo);
return new DefaultExecutableStageContext(jobBundleFactory);
}
Expand All @@ -46,54 +42,4 @@ public StageBundleFactory getStageBundleFactory(ExecutableStage executableStage)
public void close() throws Exception {
jobBundleFactory.close();
}

/**
 * {@link ExecutableStageContext.Factory} that creates and round-robins between a number of child
 * {@link ExecutableStageContext.Factory} instances.
 */
public static class MultiInstanceFactory implements ExecutableStageContext.Factory {

  // Child factories created so far; never grows beyond `parallelism` entries.
  private final List<ReferenceCountingExecutableStageContextFactory> children = new ArrayList<>();
  // Index of the next child to hand out, advanced round-robin on every request.
  private int nextChild = 0;
  private final int parallelism;
  private final SerializableFunction<Object, Boolean> isReleaseSynchronous;

  public MultiInstanceFactory(
      int maxFactories, SerializableFunction<Object, Boolean> isReleaseSynchronous) {
    this.isReleaseSynchronous = isReleaseSynchronous;
    Preconditions.checkArgument(maxFactories >= 0, "sdk_worker_parallelism must be >= 0");

    // A value of 0 means "auto": use num_cores - 1 so that some resources remain
    // available for the java process itself.
    this.parallelism =
        (maxFactories == 0)
            ? Math.max(Runtime.getRuntime().availableProcessors() - 1, 1)
            : maxFactories;
  }

  /** Returns the next child factory in round-robin order, lazily creating it if needed. */
  private synchronized ExecutableStageContext.Factory nextFactory() {
    final ReferenceCountingExecutableStageContextFactory factory;
    if (children.size() < parallelism) {
      // Below the cap: create and remember a fresh child factory.
      factory =
          ReferenceCountingExecutableStageContextFactory.create(
              DefaultExecutableStageContext::create, isReleaseSynchronous);
      children.add(factory);
    } else {
      // At the cap: reuse an existing child.
      factory = children.get(nextChild);
    }
    nextChild = (nextChild + 1) % parallelism;
    return factory;
  }

  @Override
  public ExecutableStageContext get(JobInfo jobInfo) {
    return nextFactory().get(jobInfo);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,13 @@
import org.apache.beam.sdk.options.PortablePipelineOptions;
import org.apache.beam.sdk.options.PortablePipelineOptions.RetrievalServiceType;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheLoader;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LoadingCache;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.RemovalNotification;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.slf4j.Logger;
Expand All @@ -85,7 +87,8 @@ public class DefaultJobBundleFactory implements JobBundleFactory {
private static final IdGenerator factoryIdGenerator = IdGenerators.incrementingLongs();

private final String factoryId = factoryIdGenerator.getId();
private final LoadingCache<Environment, WrappedSdkHarnessClient> environmentCache;
private final ImmutableList<LoadingCache<Environment, WrappedSdkHarnessClient>> environmentCaches;
private final AtomicInteger stageBundleCount = new AtomicInteger();
private final Map<String, EnvironmentFactory.Provider> environmentFactoryProviderMap;
private final ExecutorService executor;
private final MapControlClientPool clientPool;
Expand Down Expand Up @@ -121,8 +124,10 @@ public static DefaultJobBundleFactory create(
this.clientPool = MapControlClientPool.create();
this.stageIdGenerator = () -> factoryId + "-" + stageIdSuffixGenerator.getId();
this.environmentExpirationMillis = getEnvironmentExpirationMillis(jobInfo);
this.environmentCache =
createEnvironmentCache(serverFactory -> createServerInfo(jobInfo, serverFactory));
this.environmentCaches =
createEnvironmentCaches(
serverFactory -> createServerInfo(jobInfo, serverFactory),
getMaxEnvironmentClients(jobInfo));
}

@VisibleForTesting
Expand All @@ -136,17 +141,12 @@ public static DefaultJobBundleFactory create(
this.clientPool = MapControlClientPool.create();
this.stageIdGenerator = stageIdGenerator;
this.environmentExpirationMillis = getEnvironmentExpirationMillis(jobInfo);
this.environmentCache = createEnvironmentCache(serverFactory -> serverInfo);
this.environmentCaches =
createEnvironmentCaches(serverFactory -> serverInfo, getMaxEnvironmentClients(jobInfo));
}

/** Reads the environment expiration time, in milliseconds, from the job's pipeline options. */
private static int getEnvironmentExpirationMillis(JobInfo jobInfo) {
  return PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions())
      .as(PortablePipelineOptions.class)
      .getEnvironmentExpirationMillis();
}

private LoadingCache<Environment, WrappedSdkHarnessClient> createEnvironmentCache(
ThrowingFunction<ServerFactory, ServerInfo> serverInfoCreator) {
private ImmutableList<LoadingCache<Environment, WrappedSdkHarnessClient>> createEnvironmentCaches(
ThrowingFunction<ServerFactory, ServerInfo> serverInfoCreator, int count) {
CacheBuilder builder =
CacheBuilder.newBuilder()
.removalListener(
Expand All @@ -161,26 +161,55 @@ private LoadingCache<Environment, WrappedSdkHarnessClient> createEnvironmentCach
if (environmentExpirationMillis > 0) {
builder = builder.expireAfterWrite(environmentExpirationMillis, TimeUnit.MILLISECONDS);
}
return builder.build(
new CacheLoader<Environment, WrappedSdkHarnessClient>() {
@Override
public WrappedSdkHarnessClient load(Environment environment) throws Exception {
EnvironmentFactory.Provider environmentFactoryProvider =
environmentFactoryProviderMap.get(environment.getUrn());
ServerFactory serverFactory = environmentFactoryProvider.getServerFactory();
ServerInfo serverInfo = serverInfoCreator.apply(serverFactory);
EnvironmentFactory environmentFactory =
environmentFactoryProvider.createEnvironmentFactory(
serverInfo.getControlServer(),
serverInfo.getLoggingServer(),
serverInfo.getRetrievalServer(),
serverInfo.getProvisioningServer(),
clientPool,
stageIdGenerator);
return WrappedSdkHarnessClient.wrapping(
environmentFactory.createEnvironment(environment), serverInfo);
}
});

ImmutableList.Builder<LoadingCache<Environment, WrappedSdkHarnessClient>> caches =
ImmutableList.builder();
for (int i = 0; i < count; i++) {
LoadingCache<Environment, WrappedSdkHarnessClient> cache =
builder.build(
new CacheLoader<Environment, WrappedSdkHarnessClient>() {
@Override
public WrappedSdkHarnessClient load(Environment environment) throws Exception {
EnvironmentFactory.Provider environmentFactoryProvider =
environmentFactoryProviderMap.get(environment.getUrn());
ServerFactory serverFactory = environmentFactoryProvider.getServerFactory();
ServerInfo serverInfo = serverInfoCreator.apply(serverFactory);
EnvironmentFactory environmentFactory =
environmentFactoryProvider.createEnvironmentFactory(
serverInfo.getControlServer(),
serverInfo.getLoggingServer(),
serverInfo.getRetrievalServer(),
serverInfo.getProvisioningServer(),
clientPool,
stageIdGenerator);
return WrappedSdkHarnessClient.wrapping(
environmentFactory.createEnvironment(environment), serverInfo);
}
});
caches.add(cache);
}
return caches.build();
}

/** Reads the environment expiration time, in milliseconds, from the job's pipeline options. */
private static int getEnvironmentExpirationMillis(JobInfo jobInfo) {
PipelineOptions pipelineOptions =
PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions());
return pipelineOptions.as(PortablePipelineOptions.class).getEnvironmentExpirationMillis();
}

/**
 * Determines how many SDK harness environments to run concurrently, based on the
 * {@code sdk_worker_parallelism} pipeline option (treated as 1 when unset).
 *
 * <p>A configured value of 0 selects "auto" behavior: one environment per available
 * core, minus one, so that some resources remain available for the java process.
 */
private static int getMaxEnvironmentClients(JobInfo jobInfo) {
  PortablePipelineOptions options =
      PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions())
          .as(PortablePipelineOptions.class);
  Long configured = options.getSdkWorkerParallelism();
  // Unset parallelism defaults to a single environment.
  int maxEnvironments = (configured == null) ? 1 : configured.intValue();
  Preconditions.checkArgument(maxEnvironments >= 0, "sdk_worker_parallelism must be >= 0");
  if (maxEnvironments != 0) {
    return maxEnvironments;
  }
  // Auto mode: leave one core free for the java process itself.
  return Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
}

@Override
Expand All @@ -192,9 +221,10 @@ public StageBundleFactory forStage(ExecutableStage executableStage) {
public void close() throws Exception {
// Clear the cache. This closes all active environments.
// note this may cause open calls to be cancelled by the peer
environmentCache.invalidateAll();
environmentCache.cleanUp();

for (LoadingCache<Environment, WrappedSdkHarnessClient> environmentCache : environmentCaches) {
environmentCache.invalidateAll();
environmentCache.cleanUp();
}
executor.shutdown();
}

Expand All @@ -205,13 +235,16 @@ public void close() throws Exception {
private class SimpleStageBundleFactory implements StageBundleFactory {

private final ExecutableStage executableStage;
private final int environmentIndex;
private BundleProcessor processor;
private ExecutableProcessBundleDescriptor processBundleDescriptor;
private WrappedSdkHarnessClient wrappedClient;

private SimpleStageBundleFactory(ExecutableStage executableStage) {
this.executableStage = executableStage;
prepare(environmentCache.getUnchecked(executableStage.getEnvironment()));
this.environmentIndex = stageBundleCount.getAndIncrement() % environmentCaches.size();
prepare(
environmentCaches.get(environmentIndex).getUnchecked(executableStage.getEnvironment()));
}

private void prepare(WrappedSdkHarnessClient wrappedClient) {
Expand Down Expand Up @@ -266,7 +299,7 @@ public RemoteBundle getBundle(
}

final WrappedSdkHarnessClient client =
environmentCache.getUnchecked(executableStage.getEnvironment());
environmentCaches.get(environmentIndex).getUnchecked(executableStage.getEnvironment());
client.ref();

if (client != wrappedClient) {
Expand Down

This file was deleted.

Loading