Skip to content
This repository has been archived by the owner on Dec 13, 2023. It is now read-only.

Commit

Permalink
Deprecate shared threadpool and evenly split it between workers if th…
Browse files Browse the repository at this point in the history
…at is supplied
  • Loading branch information
jxu-nflx committed Oct 28, 2022
1 parent 312a1f7 commit 13a68e7
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class TaskPollExecutor {
TaskPollExecutor(
EurekaClient eurekaClient,
TaskClient taskClient,
int threadCount,
int updateRetryCount,
Map<String, String> taskToDomain,
String workerNamePrefix,
Expand All @@ -81,17 +80,11 @@ class TaskPollExecutor {

this.pollingSemaphoreMap = new HashMap<>();
int totalThreadCount = 0;
if (!taskThreadCount.isEmpty()) {
for (Map.Entry<String, Integer> entry : taskThreadCount.entrySet()) {
String taskType = entry.getKey();
int count = entry.getValue();
totalThreadCount += count;
pollingSemaphoreMap.put(taskType, new PollingSemaphore(count));
}
} else {
totalThreadCount = threadCount;
// shared poll for all workers
pollingSemaphoreMap.put(ALL_WORKERS, new PollingSemaphore(threadCount));
for (Map.Entry<String, Integer> entry : taskThreadCount.entrySet()) {
String taskType = entry.getKey();
int count = entry.getValue();
totalThreadCount += count;
pollingSemaphoreMap.put(taskType, new PollingSemaphore(count));
}

LOGGER.info("Initialized the TaskPollExecutor with {} threads", totalThreadCount);
Expand Down Expand Up @@ -163,8 +156,8 @@ void pollAndExecute(Worker worker) {
() ->
taskClient.batchPollTasksInDomain(
taskType,
worker.getIdentity(),
domain,
worker.getIdentity(),
slotsToAcquire,
worker.getBatchPollTimeoutInMS()));
acquiredTasks = tasks.size();
Expand Down Expand Up @@ -400,11 +393,7 @@ private void handleException(Throwable t, TaskResult result, Worker worker, Task
}

private PollingSemaphore getPollingSemaphore(String taskType) {
if (pollingSemaphoreMap.containsKey(taskType)) {
return pollingSemaphoreMap.get(taskType);
} else {
return pollingSemaphoreMap.get(ALL_WORKERS);
}
return pollingSemaphoreMap.get(taskType);
}

private Runnable extendLease(Task task, CompletableFuture<Task> taskCompletableFuture) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
package com.netflix.conductor.client.automator;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
Expand All @@ -44,7 +47,7 @@ public class TaskRunnerConfigurer {
private final List<Worker> workers = new LinkedList<>();
private final int sleepWhenRetry;
private final int updateRetryCount;
private final int threadCount;
@Deprecated private final int threadCount;
private final int shutdownGracePeriodSeconds;
private final String workerNamePrefix;
private final Map<String /*taskType*/, String /*domain*/> taskToDomain;
Expand All @@ -64,19 +67,26 @@ private TaskRunnerConfigurer(Builder builder) {
} else if (!builder.taskThreadCount.isEmpty()) {
for (Worker worker : builder.workers) {
if (!builder.taskThreadCount.containsKey(worker.getTaskDefName())) {
String message =
String.format(MISSING_TASK_THREAD_COUNT, worker.getTaskDefName());
LOGGER.error(message);
throw new ConductorClientException(message);
LOGGER.info(
"No thread count specified for task type {}, default to 1 thread",
worker.getTaskDefName());
builder.taskThreadCount.put(worker.getTaskDefName(), 1);
}
workers.add(worker);
}
this.taskThreadCount = builder.taskThreadCount;
this.threadCount = -1;
} else {
builder.workers.forEach(workers::add);
this.taskThreadCount = builder.taskThreadCount;
Set<String> taskTypes = new HashSet<>();
for (Worker worker : builder.workers) {
taskTypes.add(worker.getTaskDefName());
workers.add(worker);
}
this.threadCount = (builder.threadCount == -1) ? workers.size() : builder.threadCount;
// shared thread pool will be evenly split between task types
int splitThreadCount = threadCount / taskTypes.size();
this.taskThreadCount =
taskTypes.stream().collect(Collectors.toMap(v -> v, v -> splitThreadCount));
}

this.eurekaClient = builder.eurekaClient;
Expand All @@ -94,7 +104,7 @@ public static class Builder {
private String workerNamePrefix = "workflow-worker-%d";
private int sleepWhenRetry = 500;
private int updateRetryCount = 3;
private int threadCount = -1;
@Deprecated private int threadCount = -1;
private int shutdownGracePeriodSeconds = 10;
private final Iterable<Worker> workers;
private EurekaClient eurekaClient;
Expand Down Expand Up @@ -143,7 +153,9 @@ public Builder withUpdateRetryCount(int updateRetryCount) {
* @param threadCount # of threads assigned to the workers. Should be at-least the size of
* taskWorkers to avoid starvation in a busy system.
* @return Builder instance
* @deprecated Use {@link TaskRunnerConfigurer.Builder#withTaskThreadCount(Map)} instead.
*/
@Deprecated
public Builder withThreadCount(int threadCount) {
if (threadCount < 1) {
throw new IllegalArgumentException("No. of threads cannot be less than 1");
Expand Down Expand Up @@ -200,6 +212,7 @@ public TaskRunnerConfigurer build() {
/**
* @return Thread Count for the shared executor pool
*/
@Deprecated
public int getThreadCount() {
return threadCount;
}
Expand Down Expand Up @@ -249,7 +262,6 @@ public synchronized void init() {
new TaskPollExecutor(
eurekaClient,
taskClient,
threadCount,
updateRetryCount,
taskToDomain,
workerNamePrefix,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -48,6 +49,9 @@ public class TaskPollExecutorTest {

private static final String TEST_TASK_DEF_NAME = "test";

private static final Map<String, Integer> TASK_THREAD_MAP =
Collections.singletonMap(TEST_TASK_DEF_NAME, 1);

@Test
public void testTaskExecutionException() throws InterruptedException {
Worker worker =
Expand All @@ -59,7 +63,7 @@ public void testTaskExecutionException() throws InterruptedException {
TaskClient taskClient = Mockito.mock(TaskClient.class);
TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
null, taskClient, 1, 1, new HashMap<>(), "test-worker-%d", new HashMap<>());
null, taskClient, 1, new HashMap<>(), "test-worker-%d", TASK_THREAD_MAP);

when(taskClient.batchPollTasksInDomain(any(), any(), any(), anyInt(), anyInt()))
.thenReturn(Arrays.asList(testTask()));
Expand Down Expand Up @@ -113,7 +117,7 @@ public TaskResult answer(InvocationOnMock invocation)
TaskClient taskClient = Mockito.mock(TaskClient.class);
TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
null, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
null, taskClient, 1, new HashMap<>(), "test-worker-", TASK_THREAD_MAP);
when(taskClient.batchPollTasksInDomain(any(), any(), any(), anyInt(), anyInt()))
.thenReturn(Arrays.asList(task));
when(taskClient.ack(any(), any())).thenReturn(true);
Expand Down Expand Up @@ -173,7 +177,7 @@ public void testLargePayloadCanFailUpdateWithRetry() throws InterruptedException

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
null, taskClient, 1, 3, new HashMap<>(), "test-worker-", new HashMap<>());
null, taskClient, 1, new HashMap<>(), "test-worker-", TASK_THREAD_MAP);
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -212,7 +216,7 @@ public void testLargePayloadLocationUpdate() throws InterruptedException {

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
null, taskClient, 1, 3, new HashMap<>(), "test-worker-", new HashMap<>());
null, taskClient, 1, new HashMap<>(), "test-worker-", TASK_THREAD_MAP);
CountDownLatch latch = new CountDownLatch(1);

doAnswer(
Expand Down Expand Up @@ -243,7 +247,7 @@ public void testTaskPollException() throws InterruptedException {

Worker worker = mock(Worker.class);
when(worker.getPollingInterval()).thenReturn(3000);
when(worker.getTaskDefName()).thenReturn("test");
when(worker.getTaskDefName()).thenReturn(TEST_TASK_DEF_NAME);
when(worker.execute(any())).thenReturn(new TaskResult(task));

TaskClient taskClient = Mockito.mock(TaskClient.class);
Expand All @@ -253,7 +257,7 @@ public void testTaskPollException() throws InterruptedException {

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
null, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
null, taskClient, 1, new HashMap<>(), "test-worker-", TASK_THREAD_MAP);
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -290,7 +294,7 @@ public void testTaskPoll() throws InterruptedException {

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
null, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
null, taskClient, 1, new HashMap<>(), "test-worker-", TASK_THREAD_MAP);
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -320,7 +324,7 @@ public void testTaskPollDomain() throws InterruptedException {
taskToDomain.put(TEST_TASK_DEF_NAME, testDomain);
TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
null, taskClient, 1, 1, taskToDomain, "test-worker-", new HashMap<>());
null, taskClient, 1, taskToDomain, "test-worker-", TASK_THREAD_MAP);

String workerName = "test-worker";
Worker worker = mock(Worker.class);
Expand Down Expand Up @@ -363,7 +367,12 @@ public void testPollOutOfDiscoveryForTask() throws InterruptedException {

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
client, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
client,
taskClient,
1,
new HashMap<>(),
"test-worker-",
Collections.singletonMap("task_run_always", 1));
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -404,7 +413,7 @@ public void testPollOutOfDiscoveryAsDefaultFalseForTask()

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
client, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
client, taskClient, 1, new HashMap<>(), "test-worker-", TASK_THREAD_MAP);
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -446,7 +455,7 @@ public void testPollOutOfDiscoveryAsExplicitFalseForTask()

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
client, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
client, taskClient, 1, new HashMap<>(), "test-worker-", TASK_THREAD_MAP);
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -487,7 +496,12 @@ public void testPollOutOfDiscoveryIsIgnoredWhenDiscoveryIsUp() throws Interrupte

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
client, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
client,
taskClient,
1,
new HashMap<>(),
"test-worker-",
Collections.singletonMap("task_ignore_override", 1));
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -518,7 +532,7 @@ public void testTaskThreadCount() throws InterruptedException {

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
null, taskClient, -1, 1, new HashMap<>(), "test-worker-", taskThreadCount);
null, taskClient, -1, new HashMap<>(), "test-worker-", taskThreadCount);

String workerName = "test-worker";
Worker worker = mock(Worker.class);
Expand Down Expand Up @@ -563,7 +577,7 @@ public void testTaskLeaseExtend() throws InterruptedException {

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
null, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
null, taskClient, 1, new HashMap<>(), "test-worker-", TASK_THREAD_MAP);
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -619,13 +633,13 @@ public TaskResult answer(InvocationOnMock invocation)
});
}
when(taskClient.batchPollTasksInDomain(
TEST_TASK_DEF_NAME, workerName, null, threadCount, 1000))
TEST_TASK_DEF_NAME, null, workerName, threadCount, 1000))
.thenReturn(tasks);
when(taskClient.ack(any(), any())).thenReturn(true);

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
null, taskClient, -1, 1, new HashMap<>(), "test-worker-", taskThreadCount);
null, taskClient, 1, new HashMap<>(), "test-worker-", taskThreadCount);

CountDownLatch latch = new CountDownLatch(threadCount);
doAnswer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import static com.netflix.conductor.common.metadata.tasks.TaskResult.Status.COMPLETED;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -69,15 +69,21 @@ public void testInvalidThreadConfig() {
.build();
}

@Test(expected = ConductorClientException.class)
@Test
public void testMissingTaskThreadConfig() {
Worker worker1 = Worker.create("task1", TaskResult::new);
Worker worker2 = Worker.create("task2", TaskResult::new);
Map<String, Integer> taskThreadCount = new HashMap<>();
taskThreadCount.put(worker1.getTaskDefName(), 2);
new TaskRunnerConfigurer.Builder(client, Arrays.asList(worker1, worker2))
.withTaskThreadCount(taskThreadCount)
.build();
TaskRunnerConfigurer configurer =
new TaskRunnerConfigurer.Builder(client, Arrays.asList(worker1, worker2))
.withTaskThreadCount(taskThreadCount)
.build();

assertFalse(configurer.getTaskThreadCount().isEmpty());
assertEquals(2, configurer.getTaskThreadCount().size());
assertEquals(2, configurer.getTaskThreadCount().get("task1").intValue());
assertEquals(1, configurer.getTaskThreadCount().get("task2").intValue());
}

@Test
Expand Down Expand Up @@ -108,7 +114,9 @@ public void testSharedThreadPool() {
assertEquals(500, configurer.getSleepWhenRetry());
assertEquals(3, configurer.getUpdateRetryCount());
assertEquals(10, configurer.getShutdownGracePeriodSeconds());
assertTrue(configurer.getTaskThreadCount().isEmpty());
assertFalse(configurer.getTaskThreadCount().isEmpty());
assertEquals(1, configurer.getTaskThreadCount().size());
assertEquals(3, configurer.getTaskThreadCount().get(TEST_TASK_DEF_NAME).intValue());

configurer =
new TaskRunnerConfigurer.Builder(client, Collections.singletonList(worker))
Expand All @@ -125,7 +133,9 @@ public void testSharedThreadPool() {
assertEquals(10, configurer.getUpdateRetryCount());
assertEquals(15, configurer.getShutdownGracePeriodSeconds());
assertEquals("test-worker-", configurer.getWorkerNamePrefix());
assertTrue(configurer.getTaskThreadCount().isEmpty());
assertFalse(configurer.getTaskThreadCount().isEmpty());
assertEquals(1, configurer.getTaskThreadCount().size());
assertEquals(100, configurer.getTaskThreadCount().get(TEST_TASK_DEF_NAME).intValue());
}

@Test
Expand Down

0 comments on commit 13a68e7

Please sign in to comment.