diff --git a/flink-clients/src/main/java/org/apache/flink/client/program/PerJobMiniClusterFactory.java b/flink-clients/src/main/java/org/apache/flink/client/program/PerJobMiniClusterFactory.java index 56b85fde50879..7dc4585e58d12 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/program/PerJobMiniClusterFactory.java +++ b/flink-clients/src/main/java/org/apache/flink/client/program/PerJobMiniClusterFactory.java @@ -20,6 +20,7 @@ import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.PipelineOptions; import org.apache.flink.configuration.RestOptions; import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.core.execution.JobClient; @@ -34,6 +35,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.function.Function; @@ -127,14 +129,24 @@ private MiniClusterConfiguration getMiniClusterConfig(int maximumParallelism) { ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, ConfigConstants.DEFAULT_LOCAL_NUMBER_TASK_MANAGER); + Map overwriteParallelisms = + configuration.get(PipelineOptions.PARALLELISM_OVERRIDES); + if (overwriteParallelisms != null) { + for (String overrideParallelism : overwriteParallelisms.values()) { + maximumParallelism = + Math.max(maximumParallelism, Integer.parseInt(overrideParallelism)); + } + } + + int finalMaximumParallelism = maximumParallelism; int numSlotsPerTaskManager = configuration .getOptional(TaskManagerOptions.NUM_TASK_SLOTS) .orElseGet( () -> - maximumParallelism > 0 + finalMaximumParallelism > 0 ? MathUtils.divideRoundUp( - maximumParallelism, numTaskManagers) + finalMaximumParallelism, numTaskManagers) : TaskManagerOptions.NUM_TASK_SLOTS.defaultValue()); return new MiniClusterConfiguration.Builder() diff --git a/flink-clients/src/test/java/org/apache/flink/client/program/PerJobMiniClusterFactoryTest.java b/flink-clients/src/test/java/org/apache/flink/client/program/PerJobMiniClusterFactoryTest.java index a439d10114a10..b48d867c8e9e4 100644 --- a/flink-clients/src/test/java/org/apache/flink/client/program/PerJobMiniClusterFactoryTest.java +++ b/flink-clients/src/test/java/org/apache/flink/client/program/PerJobMiniClusterFactoryTest.java @@ -21,21 +21,28 @@ import org.apache.flink.api.common.JobExecutionResult; import org.apache.flink.api.common.JobStatus; import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.PipelineOptions; import org.apache.flink.core.execution.JobClient; import org.apache.flink.core.execution.SavepointFormatType; +import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobGraphTestUtils; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.testutils.CancelableInvokable; import org.apache.flink.runtime.testutils.WaitingCancelableInvokable; +import org.apache.flink.shaded.guava30.com.google.common.collect.ImmutableMap; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import java.time.Duration; import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; +import static org.apache.flink.util.Preconditions.checkState; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -152,9 +159,44 @@ void testJobClientInteractionAfterShutdown() throws Exception { "MiniCluster is not yet running or has already been shut down."); } + @Test + void testTurnUpParallelismByOverwriteParallelism() throws Exception { + JobVertex jobVertex = getBlockingJobVertex(); + JobGraph jobGraph = JobGraphTestUtils.streamingJobGraph(jobVertex); + int overwriteParallelism = jobVertex.getParallelism() + 1; + BlockingInvokable.reset(overwriteParallelism); + + Configuration configuration = new Configuration(); + configuration.set( + PipelineOptions.PARALLELISM_OVERRIDES, + ImmutableMap.of( + jobVertex.getID().toHexString(), String.valueOf(overwriteParallelism))); + + PerJobMiniClusterFactory perJobMiniClusterFactory = initializeMiniCluster(configuration); + JobClient jobClient = + perJobMiniClusterFactory + .submitJob(jobGraph, ClassLoader.getSystemClassLoader()) + .get(); + + // wait for tasks to be properly running + BlockingInvokable.latch.await(); + + jobClient.cancel().get(); + assertThat(jobClient.getJobExecutionResult()) + .failsWithin(Duration.ofSeconds(1)) + .withThrowableOfType(ExecutionException.class) + .withMessageContaining("Job was cancelled"); + + assertThatMiniClusterIsShutdown(); + } + private PerJobMiniClusterFactory initializeMiniCluster() { + return initializeMiniCluster(new Configuration()); + } + + private PerJobMiniClusterFactory initializeMiniCluster(Configuration configuration) { return PerJobMiniClusterFactory.createWithFactory( - new Configuration(), + configuration, config -> { miniCluster = new MiniCluster(config); return miniCluster; @@ -175,4 +217,32 @@ private static JobGraph getCancellableJobGraph() { jobVertex.setParallelism(1); return JobGraphTestUtils.streamingJobGraph(jobVertex); } + + private static JobVertex getBlockingJobVertex() { + JobVertex jobVertex = new JobVertex("jobVertex"); + jobVertex.setInvokableClass(BlockingInvokable.class); + jobVertex.setParallelism(2); + return jobVertex; + } + + /** Test invokable that allows waiting for all subtasks to be running. */ + public static class BlockingInvokable extends CancelableInvokable { + + private static CountDownLatch latch; + + public BlockingInvokable(Environment environment) { + super(environment); + } + + @Override + public void doInvoke() throws Exception { + checkState(latch != null, "The invokable should be reset first."); + latch.countDown(); + waitUntilCancelled(); + } + + public static void reset(int count) { + latch = new CountDownLatch(count); + } + } }