From 20c43cde737d20c73b8f0d6e248fd4d4d1ec0011 Mon Sep 17 00:00:00 2001 From: PritLadani Date: Tue, 6 Sep 2022 17:16:44 +0530 Subject: [PATCH] Backporting RTF Backporting pull requests https://github.com/opensearch-project/OpenSearch/pull/2089 and https://github.com/opensearch-project/OpenSearch/pull/3982 Signed-off-by: PritLadani --- .../org/opensearch/client/tasks/TaskInfo.java | 18 +- .../core/tasks/GetTaskResponseTests.java | 29 +- .../tasks/CancelTasksResponseTests.java | 3 +- .../TransportRethrottleActionTests.java | 6 +- .../admin/cluster/node/tasks/TasksIT.java | 20 +- .../src/main/java/org/opensearch/Version.java | 3 + .../tasks/list/TransportListTasksAction.java | 13 +- .../action/search/SearchShardTask.java | 5 + .../opensearch/action/search/SearchTask.java | 5 + .../action/support/TransportAction.java | 78 ++- .../org/opensearch/cluster/ClusterModule.java | 2 + .../common/settings/ClusterSettings.java | 4 +- .../util/concurrent/OpenSearchExecutors.java | 46 +- .../common/util/concurrent/ThreadContext.java | 16 +- .../main/java/org/opensearch/node/Node.java | 12 +- .../rest/action/cat/RestTasksAction.java | 2 + .../org/opensearch/tasks/ResourceStats.java | 28 + .../opensearch/tasks/ResourceStatsType.java | 32 + .../opensearch/tasks/ResourceUsageInfo.java | 108 +++ .../opensearch/tasks/ResourceUsageMetric.java | 27 + .../main/java/org/opensearch/tasks/Task.java | 252 ++++++- .../java/org/opensearch/tasks/TaskInfo.java | 62 +- .../org/opensearch/tasks/TaskManager.java | 50 +- .../opensearch/tasks/TaskResourceStats.java | 106 +++ .../tasks/TaskResourceTrackingService.java | 248 +++++++ .../opensearch/tasks/TaskResourceUsage.java | 105 +++ .../opensearch/tasks/ThreadResourceInfo.java | 60 ++ .../AutoQueueAdjustingExecutorBuilder.java | 19 +- .../RunnableTaskExecutionListener.java | 33 + .../threadpool/TaskAwareRunnable.java | 90 +++ .../org/opensearch/threadpool/ThreadPool.java | 22 +- .../transport/RequestHandlerRegistry.java | 4 + .../tasks/RecordingTaskManagerListener.java | 3 + .../node/tasks/ResourceAwareTasksTests.java | 653 ++++++++++++++++++ .../node/tasks/TaskManagerTestCase.java | 17 +- .../admin/cluster/node/tasks/TaskTests.java | 74 +- .../bulk/TransportBulkActionIngestTests.java | 3 +- .../util/concurrent/ThreadContextTests.java | 10 + .../snapshots/SnapshotResiliencyTests.java | 3 + .../tasks/CancelTasksResponseTests.java | 2 +- .../tasks/ListTasksResponseTests.java | 18 +- .../org/opensearch/tasks/TaskInfoTests.java | 89 ++- .../opensearch/tasks/TaskManagerTests.java | 6 +- .../TaskResourceTrackingServiceTests.java | 97 +++ .../test/tasks/MockTaskManager.java | 16 + .../test/tasks/MockTaskManagerListener.java | 3 + .../opensearch/threadpool/TestThreadPool.java | 20 +- 47 files changed, 2432 insertions(+), 90 deletions(-) create mode 100644 server/src/main/java/org/opensearch/tasks/ResourceStats.java create mode 100644 server/src/main/java/org/opensearch/tasks/ResourceStatsType.java create mode 100644 server/src/main/java/org/opensearch/tasks/ResourceUsageInfo.java create mode 100644 server/src/main/java/org/opensearch/tasks/ResourceUsageMetric.java create mode 100644 server/src/main/java/org/opensearch/tasks/TaskResourceStats.java create mode 100644 server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java create mode 100644 server/src/main/java/org/opensearch/tasks/TaskResourceUsage.java create mode 100644 server/src/main/java/org/opensearch/tasks/ThreadResourceInfo.java create mode 100644 server/src/main/java/org/opensearch/threadpool/RunnableTaskExecutionListener.java create mode 100644 server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java create mode 100644 server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java create mode 100644 server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java diff --git a/client/rest-high-level/src/main/java/org/opensearch/client/tasks/TaskInfo.java b/client/rest-high-level/src/main/java/org/opensearch/client/tasks/TaskInfo.java index 062fbe56e4ed9..0e56df1c5021a 100644 --- a/client/rest-high-level/src/main/java/org/opensearch/client/tasks/TaskInfo.java +++ b/client/rest-high-level/src/main/java/org/opensearch/client/tasks/TaskInfo.java @@ -56,6 +56,7 @@ public class TaskInfo { private TaskId parentTaskId; private final Map status = new HashMap<>(); private final Map headers = new HashMap<>(); + private final Map resourceStats = new HashMap<>(); public TaskInfo(TaskId taskId) { this.taskId = taskId; @@ -141,6 +142,14 @@ public Map getStatus() { return status; } + void setResourceStats(Map resourceStats) { + this.resourceStats.putAll(resourceStats); + } + + public Map getResourceStats() { + return resourceStats; + } + private void noOpParse(Object s) {} public static final ObjectParser.NamedObjectParser PARSER; @@ -160,6 +169,7 @@ private void noOpParse(Object s) {} parser.declareBoolean(TaskInfo::setCancellable, new ParseField("cancellable")); parser.declareString(TaskInfo::setParentTaskId, new ParseField("parent_task_id")); parser.declareObject(TaskInfo::setHeaders, (p, c) -> p.mapStrings(), new ParseField("headers")); + parser.declareObject(TaskInfo::setResourceStats, (p, c) -> p.map(), new ParseField("resource_stats")); PARSER = (XContentParser p, Void v, String name) -> parser.parse(p, new TaskInfo(new TaskId(name)), null); } @@ -177,7 +187,8 @@ && isCancellable() == taskInfo.isCancellable() && Objects.equals(getDescription(), taskInfo.getDescription()) && Objects.equals(getParentTaskId(), taskInfo.getParentTaskId()) && Objects.equals(status, taskInfo.status) - && Objects.equals(getHeaders(), taskInfo.getHeaders()); + && Objects.equals(getHeaders(), taskInfo.getHeaders()) + && Objects.equals(getResourceStats(), taskInfo.getResourceStats()); } @Override @@ -192,7 +203,8 @@ public int hashCode() { isCancellable(), getParentTaskId(), status, - getHeaders() + getHeaders(), + getResourceStats() ); } @@ -222,6 +234,8 @@ public String toString() { + status + ", headers=" + headers + + ", resource_stats=" + + resourceStats + '}'; } } diff --git a/client/rest-high-level/src/test/java/org/opensearch/client/core/tasks/GetTaskResponseTests.java b/client/rest-high-level/src/test/java/org/opensearch/client/core/tasks/GetTaskResponseTests.java index a14e1169d09fc..7e706c1b87559 100644 --- a/client/rest-high-level/src/test/java/org/opensearch/client/core/tasks/GetTaskResponseTests.java +++ b/client/rest-high-level/src/test/java/org/opensearch/client/core/tasks/GetTaskResponseTests.java @@ -38,6 +38,8 @@ import org.opensearch.common.xcontent.ToXContent; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.tasks.RawTaskStatus; +import org.opensearch.tasks.TaskResourceStats; +import org.opensearch.tasks.TaskResourceUsage; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskInfo; @@ -45,6 +47,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import static org.opensearch.test.AbstractXContentTestCase.xContentTester; @@ -57,7 +60,7 @@ public void testFromXContent() throws IOException { ) .assertEqualsConsumer(this::assertEqualInstances) .assertToXContentEquivalence(true) - .randomFieldsExcludeFilter(field -> field.endsWith("headers") || field.endsWith("status")) + .randomFieldsExcludeFilter(field -> field.endsWith("headers") || field.endsWith("status") || field.contains("resource_stats")) .test(); } @@ -94,7 +97,19 @@ static TaskInfo randomTaskInfo() { Map headers = randomBoolean() ? Collections.emptyMap() : Collections.singletonMap(randomAlphaOfLength(5), randomAlphaOfLength(5)); - return new TaskInfo(taskId, type, action, description, status, startTime, runningTimeNanos, cancellable, parentTaskId, headers); + return new TaskInfo( + taskId, + type, + action, + description, + status, + startTime, + runningTimeNanos, + cancellable, + parentTaskId, + headers, + randomResourceStats() + ); } private static TaskId randomTaskId() { @@ -114,4 +129,14 @@ private static RawTaskStatus randomRawTaskStatus() { throw new IllegalStateException(e); } } + + private static TaskResourceStats randomResourceStats() { + return randomBoolean() ? null : new TaskResourceStats(new HashMap() { + { + for (int i = 0; i < randomInt(5); i++) { + put(randomAlphaOfLength(5), new TaskResourceUsage(randomNonNegativeLong(), randomNonNegativeLong())); + } + } + }); + } } diff --git a/client/rest-high-level/src/test/java/org/opensearch/client/tasks/CancelTasksResponseTests.java b/client/rest-high-level/src/test/java/org/opensearch/client/tasks/CancelTasksResponseTests.java index 102ebb5fcd390..d5a6a9767a923 100644 --- a/client/rest-high-level/src/test/java/org/opensearch/client/tasks/CancelTasksResponseTests.java +++ b/client/rest-high-level/src/test/java/org/opensearch/client/tasks/CancelTasksResponseTests.java @@ -93,7 +93,8 @@ protected CancelTasksResponseTests.ByNodeCancelTasksResponse createServerTestIns randomIntBetween(5, 10), false, new TaskId("node1", randomLong()), - Collections.singletonMap("x-header-of", "some-value") + Collections.singletonMap("x-header-of", "some-value"), + null ) ); } diff --git a/modules/reindex/src/test/java/org/opensearch/index/reindex/TransportRethrottleActionTests.java b/modules/reindex/src/test/java/org/opensearch/index/reindex/TransportRethrottleActionTests.java index 4e6d3401a2f14..30d2e87c1d9d9 100644 --- a/modules/reindex/src/test/java/org/opensearch/index/reindex/TransportRethrottleActionTests.java +++ b/modules/reindex/src/test/java/org/opensearch/index/reindex/TransportRethrottleActionTests.java @@ -130,7 +130,8 @@ public void testRethrottleSuccessfulResponse() { 0, true, new TaskId("test", task.getId()), - Collections.emptyMap() + Collections.emptyMap(), + null ) ); sliceStatuses.add(new BulkByScrollTask.StatusOrException(status)); @@ -165,7 +166,8 @@ public void testRethrottleWithSomeSucceeded() { 0, true, new TaskId("test", task.getId()), - Collections.emptyMap() + Collections.emptyMap(), + null ) ); sliceStatuses.add(new BulkByScrollTask.StatusOrException(status)); diff --git a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java index 02034ae05dc03..97eafc8dceb69 100644 --- a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java @@ -478,6 +478,9 @@ public void onTaskUnregistered(Task task) {} @Override public void waitForTaskCompletion(Task task) {} + + @Override + public void taskExecutionStarted(Task task, Boolean closeableInvoked) {} }); } // Need to run the task in a separate thread because node client's .execute() is blocked by our task listener @@ -659,6 +662,9 @@ public void waitForTaskCompletion(Task task) { waitForWaitingToStart.countDown(); } + @Override + public void taskExecutionStarted(Task task, Boolean closeableInvoked) {} + @Override public void onTaskRegistered(Task task) {} @@ -901,7 +907,19 @@ public void testNodeNotFoundButTaskFound() throws Exception { TaskResultsService resultsService = internalCluster().getInstance(TaskResultsService.class); resultsService.storeResult( new TaskResult( - new TaskInfo(new TaskId("fake", 1), "test", "test", "", null, 0, 0, false, TaskId.EMPTY_TASK_ID, Collections.emptyMap()), + new TaskInfo( + new TaskId("fake", 1), + "test", + "test", + "", + null, + 0, + 0, + false, + TaskId.EMPTY_TASK_ID, + Collections.emptyMap(), + null + ), new RuntimeException("test") ), new ActionListener() { diff --git a/server/src/main/java/org/opensearch/Version.java b/server/src/main/java/org/opensearch/Version.java index 7401ec4472264..c9ddea6781762 100644 --- a/server/src/main/java/org/opensearch/Version.java +++ b/server/src/main/java/org/opensearch/Version.java @@ -86,6 +86,9 @@ public class Version implements Comparable, ToXContentFragment { public static final Version V_1_3_3 = new Version(1030399, org.apache.lucene.util.Version.LUCENE_8_10_1); public static final Version V_1_3_4 = new Version(1030499, org.apache.lucene.util.Version.LUCENE_8_10_1); public static final Version V_1_3_5 = new Version(1030599, org.apache.lucene.util.Version.LUCENE_8_10_1); + public static final Version V_2_0_0 = new Version(2000099, org.apache.lucene.util.Version.LUCENE_8_10_1); // TODO: Need to change the + // lucene version to + // LUCENE_9_1_0 public static final Version CURRENT = V_1_3_5; public static Version readVersion(StreamInput in) throws IOException { diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java b/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java index b7875c5f99774..df448d2665434 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java @@ -42,6 +42,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskInfo; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -60,8 +61,15 @@ public static long waitForCompletionTimeout(TimeValue timeout) { private static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = timeValueSeconds(30); + private final TaskResourceTrackingService taskResourceTrackingService; + @Inject - public TransportListTasksAction(ClusterService clusterService, TransportService transportService, ActionFilters actionFilters) { + public TransportListTasksAction( + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + TaskResourceTrackingService taskResourceTrackingService + ) { super( ListTasksAction.NAME, clusterService, @@ -72,6 +80,7 @@ public TransportListTasksAction(ClusterService clusterService, TransportService TaskInfo::new, ThreadPool.Names.MANAGEMENT ); + this.taskResourceTrackingService = taskResourceTrackingService; } @Override @@ -101,6 +110,8 @@ protected void processTasks(ListTasksRequest request, Consumer operation) } taskManager.waitForTaskCompletion(task, timeoutNanos); }); + } else { + operation = operation.andThen(taskResourceTrackingService::refreshResourceStats); } super.processTasks(request, operation); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchShardTask.java b/server/src/main/java/org/opensearch/action/search/SearchShardTask.java index 2e506c6fe181b..f09701c7769eb 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchShardTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchShardTask.java @@ -49,6 +49,11 @@ public SearchShardTask(long id, String type, String action, String description, super(id, type, action, description, parentTaskId, headers); } + @Override + public boolean supportsResourceTracking() { + return true; + } + @Override public boolean shouldCancelChildrenOnCancellation() { return false; diff --git a/server/src/main/java/org/opensearch/action/search/SearchTask.java b/server/src/main/java/org/opensearch/action/search/SearchTask.java index 7f80f7836be6c..bf6f141a3e829 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTask.java @@ -78,6 +78,11 @@ public final String getDescription() { return descriptionSupplier.get(); } + @Override + public boolean supportsResourceTracking() { + return true; + } + /** * Attach a {@link SearchProgressListener} to this task. */ diff --git a/server/src/main/java/org/opensearch/action/support/TransportAction.java b/server/src/main/java/org/opensearch/action/support/TransportAction.java index 84ece8cfec530..83fca715c7e28 100644 --- a/server/src/main/java/org/opensearch/action/support/TransportAction.java +++ b/server/src/main/java/org/opensearch/action/support/TransportAction.java @@ -40,6 +40,7 @@ import org.opensearch.action.ActionResponse; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskCancelledException; import org.opensearch.tasks.TaskId; @@ -88,31 +89,39 @@ public final Task execute(Request request, ActionListener listener) { */ final Releasable unregisterChildNode = registerChildNode(request.getParentTask()); final Task task; + try { task = taskManager.register("transport", actionName, request); } catch (TaskCancelledException e) { unregisterChildNode.close(); throw e; } - execute(task, request, new ActionListener() { - @Override - public void onResponse(Response response) { - try { - Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); - } finally { - listener.onResponse(response); + + ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task); + try { + execute(task, request, new ActionListener() { + @Override + public void onResponse(Response response) { + try { + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); + } finally { + listener.onResponse(response); + } } - } - @Override - public void onFailure(Exception e) { - try { - Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); - } finally { - listener.onFailure(e); + @Override + public void onFailure(Exception e) { + try { + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); + } finally { + listener.onFailure(e); + } } - } - }); + }); + } finally { + storedContext.close(); + } + return task; } @@ -129,25 +138,30 @@ public final Task execute(Request request, TaskListener listener) { unregisterChildNode.close(); throw e; } - execute(task, request, new ActionListener() { - @Override - public void onResponse(Response response) { - try { - Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); - } finally { - listener.onResponse(task, response); + ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task); + try { + execute(task, request, new ActionListener() { + @Override + public void onResponse(Response response) { + try { + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); + } finally { + listener.onResponse(task, response); + } } - } - @Override - public void onFailure(Exception e) { - try { - Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); - } finally { - listener.onFailure(task, e); + @Override + public void onFailure(Exception e) { + try { + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); + } finally { + listener.onFailure(task, e); + } } - } - }); + }); + } finally { + storedContext.close(); + } return task; } diff --git a/server/src/main/java/org/opensearch/cluster/ClusterModule.java b/server/src/main/java/org/opensearch/cluster/ClusterModule.java index c85691b80d7c3..b9f3a2a99f0b7 100644 --- a/server/src/main/java/org/opensearch/cluster/ClusterModule.java +++ b/server/src/main/java/org/opensearch/cluster/ClusterModule.java @@ -94,6 +94,7 @@ import org.opensearch.script.ScriptMetadata; import org.opensearch.snapshots.SnapshotsInfoService; import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.tasks.TaskResultsService; import java.util.ArrayList; @@ -394,6 +395,7 @@ protected void configure() { bind(NodeMappingRefreshAction.class).asEagerSingleton(); bind(MappingUpdatedAction.class).asEagerSingleton(); bind(TaskResultsService.class).asEagerSingleton(); + bind(TaskResourceTrackingService.class).asEagerSingleton(); bind(AllocationDeciders.class).toInstance(allocationDeciders); bind(ShardsAllocator.class).toInstance(shardsAllocator); } diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index 3c5bc91b9fdb5..d21d4324b934f 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -39,6 +39,7 @@ import org.opensearch.index.ShardIndexingPressureMemoryManager; import org.opensearch.index.ShardIndexingPressureSettings; import org.opensearch.index.ShardIndexingPressureStore; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.watcher.ResourceWatcherService; import org.opensearch.action.admin.cluster.configuration.TransportAddVotingConfigExclusionsAction; import org.opensearch.action.admin.indices.close.TransportCloseIndexAction; @@ -603,7 +604,8 @@ public void apply(Settings value, Settings current, Settings previous) { NodeLoadAwareAllocationDecider.CLUSTER_ROUTING_ALLOCATION_LOAD_AWARENESS_PROVISIONED_CAPACITY_SETTING, NodeLoadAwareAllocationDecider.CLUSTER_ROUTING_ALLOCATION_LOAD_AWARENESS_SKEW_FACTOR_SETTING, NodeLoadAwareAllocationDecider.CLUSTER_ROUTING_ALLOCATION_LOAD_AWARENESS_ALLOW_UNASSIGNED_PRIMARIES_SETTING, - FsHealthService.HEALTHY_TIMEOUT_SETTING + FsHealthService.HEALTHY_TIMEOUT_SETTING, + TaskResourceTrackingService.TASK_RESOURCE_TRACKING_ENABLED ) ) ); diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java b/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java index 5a967528a6ae2..813a67693fd20 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java @@ -40,6 +40,8 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.node.Node; +import org.opensearch.threadpool.RunnableTaskExecutionListener; +import org.opensearch.threadpool.TaskAwareRunnable; import java.util.List; import java.util.Optional; @@ -55,6 +57,7 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; public class OpenSearchExecutors { @@ -172,6 +175,31 @@ public static OpenSearchThreadPoolExecutor newFixed( ); } + public static OpenSearchThreadPoolExecutor newAutoQueueFixed( + String name, + int size, + int initialQueueCapacity, + int minQueueSize, + int maxQueueSize, + int frameSize, + TimeValue targetedResponseTime, + ThreadFactory threadFactory, + ThreadContext contextHolder + ) { + return newAutoQueueFixed( + name, + size, + initialQueueCapacity, + minQueueSize, + maxQueueSize, + frameSize, + targetedResponseTime, + threadFactory, + contextHolder, + null + ); + } + /** * Return a new executor that will automatically adjust the queue size based on queue throughput. * @@ -180,6 +208,7 @@ public static OpenSearchThreadPoolExecutor newFixed( * @param minQueueSize minimum queue size that the queue can be adjusted to * @param maxQueueSize maximum queue size that the queue can be adjusted to * @param frameSize number of tasks during which stats are collected before adjusting queue size + * @param runnableTaskListener callback listener for a TaskAwareRunnable */ public static OpenSearchThreadPoolExecutor newAutoQueueFixed( String name, @@ -190,17 +219,30 @@ public static OpenSearchThreadPoolExecutor newAutoQueueFixed( int frameSize, TimeValue targetedResponseTime, ThreadFactory threadFactory, - ThreadContext contextHolder + ThreadContext contextHolder, + AtomicReference runnableTaskListener ) { if (initialQueueCapacity <= 0) { throw new IllegalArgumentException( "initial queue capacity for [" + name + "] executor must be positive, got: " + initialQueueCapacity ); } + ResizableBlockingQueue queue = new ResizableBlockingQueue<>( ConcurrentCollections.newBlockingQueue(), initialQueueCapacity ); + + Function runnableWrapper; + if (runnableTaskListener != null) { + runnableWrapper = (runnable) -> { + TaskAwareRunnable taskAwareRunnable = new TaskAwareRunnable(contextHolder, runnable, runnableTaskListener); + return new TimedRunnable(taskAwareRunnable); + }; + } else { + runnableWrapper = TimedRunnable::new; + } + return new QueueResizingOpenSearchThreadPoolExecutor( name, size, @@ -210,7 +252,7 @@ public static OpenSearchThreadPoolExecutor newAutoQueueFixed( queue, minQueueSize, maxQueueSize, - TimedRunnable::new, + runnableWrapper, frameSize, targetedResponseTime, threadFactory, diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java index d844a8f158ea4..35d7d925ce106 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java @@ -66,6 +66,7 @@ import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_COUNT; import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_SIZE; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; /** * A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with @@ -134,16 +135,23 @@ public StoredContext stashContext() { * This is needed so the DeprecationLogger in another thread can see the value of X-Opaque-ID provided by a user. * Otherwise when context is stash, it should be empty. */ + + ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT; + if (context.requestHeaders.containsKey(Task.X_OPAQUE_ID)) { - ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT.putHeaders( + threadContextStruct = threadContextStruct.putHeaders( MapBuilder.newMapBuilder() .put(Task.X_OPAQUE_ID, context.requestHeaders.get(Task.X_OPAQUE_ID)) .immutableMap() ); - threadLocal.set(threadContextStruct); - } else { - threadLocal.set(DEFAULT_CONTEXT); } + + if (context.transientHeaders.containsKey(TASK_ID)) { + threadContextStruct = threadContextStruct.putTransient(TASK_ID, context.transientHeaders.get(TASK_ID)); + } + + threadLocal.set(threadContextStruct); + return () -> { // If the node and thus the threadLocal get closed while this task // is still executing, we don't want this runnable to fail with an diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index b8068a92443bc..f5994beb7d33a 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -37,6 +37,8 @@ import org.apache.lucene.util.Constants; import org.apache.lucene.util.SetOnce; import org.opensearch.index.IndexingPressureService; +import org.opensearch.tasks.TaskResourceTrackingService; +import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.watcher.ResourceWatcherService; import org.opensearch.Assertions; import org.opensearch.Build; @@ -212,6 +214,7 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.UnaryOperator; import java.util.stream.Collectors; @@ -314,6 +317,7 @@ public class Node implements Closeable { private final Collection pluginLifecycleComponents; private final LocalNodeFactory localNodeFactory; private final NodeService nodeService; + private final AtomicReference runnableTaskListener; public Node(Environment environment) { this(environment, Collections.emptyList(), true); @@ -427,7 +431,8 @@ protected Node( final List> executorBuilders = pluginsService.getExecutorBuilders(settings); - final ThreadPool threadPool = new ThreadPool(settings, executorBuilders.toArray(new ExecutorBuilder[0])); + runnableTaskListener = new AtomicReference<>(); + final ThreadPool threadPool = new ThreadPool(settings, runnableTaskListener, executorBuilders.toArray(new ExecutorBuilder[0])); resourcesToClose.add(() -> ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS)); final ResourceWatcherService resourceWatcherService = new ResourceWatcherService(settings, threadPool); resourcesToClose.add(resourceWatcherService); @@ -1049,6 +1054,11 @@ public Node start() throws NodeValidationException { TransportService transportService = injector.getInstance(TransportService.class); transportService.getTaskManager().setTaskResultsService(injector.getInstance(TaskResultsService.class)); transportService.getTaskManager().setTaskCancellationService(new TaskCancellationService(transportService)); + + TaskResourceTrackingService taskResourceTrackingService = injector.getInstance(TaskResourceTrackingService.class); + transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService); + runnableTaskListener.set(taskResourceTrackingService); + transportService.start(); assert localNodeFactory.getNode() != null; assert transportService.getLocalNode().equals(localNodeFactory.getNode()) diff --git a/server/src/main/java/org/opensearch/rest/action/cat/RestTasksAction.java b/server/src/main/java/org/opensearch/rest/action/cat/RestTasksAction.java index b87205593ce87..a6624c2f8cfdc 100644 --- a/server/src/main/java/org/opensearch/rest/action/cat/RestTasksAction.java +++ b/server/src/main/java/org/opensearch/rest/action/cat/RestTasksAction.java @@ -137,6 +137,7 @@ protected Table getTableWithHeader(final RestRequest request) { // Task detailed info if (detailed) { table.addCell("description", "default:true;alias:desc;desc:task action"); + table.addCell("resource_stats", "default:false;desc:resource consumption info of the task"); } table.endHeaders(); return table; @@ -173,6 +174,7 @@ private void buildRow(Table table, boolean fullId, boolean detailed, DiscoveryNo if (detailed) { table.addCell(taskInfo.getDescription()); + table.addCell(taskInfo.getResourceStats()); } table.endRow(); } diff --git a/server/src/main/java/org/opensearch/tasks/ResourceStats.java b/server/src/main/java/org/opensearch/tasks/ResourceStats.java new file mode 100644 index 0000000000000..aab103ad08dcf --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/ResourceStats.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +/** + * Different resource stats are defined. + */ +public enum ResourceStats { + CPU("cpu_time_in_nanos"), + MEMORY("memory_in_bytes"); + + private final String statsName; + + ResourceStats(String statsName) { + this.statsName = statsName; + } + + @Override + public String toString() { + return statsName; + } +} diff --git a/server/src/main/java/org/opensearch/tasks/ResourceStatsType.java b/server/src/main/java/org/opensearch/tasks/ResourceStatsType.java new file mode 100644 index 0000000000000..c670ac5ba689c --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/ResourceStatsType.java @@ -0,0 +1,32 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +/** Defines the different types of resource stats. */ +public enum ResourceStatsType { + // resource stats of the worker thread reported directly from runnable. + WORKER_STATS("worker_stats", false); + + private final String statsType; + private final boolean onlyForAnalysis; + + ResourceStatsType(String statsType, boolean onlyForAnalysis) { + this.statsType = statsType; + this.onlyForAnalysis = onlyForAnalysis; + } + + public boolean isOnlyForAnalysis() { + return onlyForAnalysis; + } + + @Override + public String toString() { + return statsType; + } +} diff --git a/server/src/main/java/org/opensearch/tasks/ResourceUsageInfo.java b/server/src/main/java/org/opensearch/tasks/ResourceUsageInfo.java new file mode 100644 index 0000000000000..ae58f712b63c2 --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/ResourceUsageInfo.java @@ -0,0 +1,108 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.util.Collections; +import java.util.EnumMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Thread resource usage information for particular resource stats type. + *

+ * It captures the resource usage information like memory, CPU about a particular execution of thread + * for a specific stats type. + */ +public class ResourceUsageInfo { + private static final Logger logger = LogManager.getLogger(ResourceUsageInfo.class); + private final EnumMap statsInfo = new EnumMap<>(ResourceStats.class); + + public ResourceUsageInfo(ResourceUsageMetric... resourceUsageMetrics) { + for (ResourceUsageMetric resourceUsageMetric : resourceUsageMetrics) { + this.statsInfo.put(resourceUsageMetric.getStats(), new ResourceStatsInfo(resourceUsageMetric.getValue())); + } + } + + public void recordResourceUsageMetrics(ResourceUsageMetric... resourceUsageMetrics) { + for (ResourceUsageMetric resourceUsageMetric : resourceUsageMetrics) { + final ResourceStatsInfo resourceStatsInfo = statsInfo.get(resourceUsageMetric.getStats()); + if (resourceStatsInfo != null) { + updateResourceUsageInfo(resourceStatsInfo, resourceUsageMetric); + } else { + throw new IllegalStateException( + "cannot update [" + + resourceUsageMetric.getStats().toString() + + "] entry as its not present current_stats_info:" + + statsInfo + ); + } + } + } + + private void updateResourceUsageInfo(ResourceStatsInfo resourceStatsInfo, ResourceUsageMetric resourceUsageMetric) { + long currentEndValue; + long newEndValue; + do { + currentEndValue = resourceStatsInfo.endValue.get(); + newEndValue = resourceUsageMetric.getValue(); + if (currentEndValue > newEndValue) { + logger.debug( + "dropping resource usage update as the new value is lower than current value [" + + "resource_stats=[{}], " + + "current_end_value={}, " + + "new_end_value={}]", + resourceUsageMetric.getStats(), + currentEndValue, + newEndValue + ); + return; + } + } while (!resourceStatsInfo.endValue.compareAndSet(currentEndValue, newEndValue)); + logger.debug( + "updated resource usage info [resource_stats=[{}], " + "old_end_value={}, new_end_value={}]", + resourceUsageMetric.getStats(), + currentEndValue, + newEndValue + ); + } + + public Map getStatsInfo() { + return Collections.unmodifiableMap(statsInfo); + } + + @Override + public String toString() { + return statsInfo.toString(); + } + + /** + * Defines resource stats information. + */ + static class ResourceStatsInfo { + private final long startValue; + private final AtomicLong endValue; + + private ResourceStatsInfo(long startValue) { + this.startValue = startValue; + this.endValue = new AtomicLong(startValue); + } + + public long getTotalValue() { + return endValue.get() - startValue; + } + + @Override + public String toString() { + return String.valueOf(getTotalValue()); + } + } +} diff --git a/server/src/main/java/org/opensearch/tasks/ResourceUsageMetric.java b/server/src/main/java/org/opensearch/tasks/ResourceUsageMetric.java new file mode 100644 index 0000000000000..0d13ffe6ec01a --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/ResourceUsageMetric.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +public class ResourceUsageMetric { + private final ResourceStats stats; + private final long value; + + public ResourceUsageMetric(ResourceStats stats, long value) { + this.stats = stats; + this.value = value; + } + + public ResourceStats getStats() { + return stats; + } + + public long getValue() { + return value; + } +} diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index 8646a97da5cfe..9a8940dc61823 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -32,25 +32,39 @@ package org.opensearch.tasks; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.ActionResponse; +import org.opensearch.action.NotifyOnceListener; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.NamedWriteable; import org.opensearch.common.xcontent.ToXContent; import org.opensearch.common.xcontent.ToXContentObject; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; /** * Current task information */ public class Task { + private static final Logger logger = LogManager.getLogger(Task.class); + /** * The request header to mark tasks with specific ids */ public static final String X_OPAQUE_ID = "X-Opaque-Id"; + private static final String TOTAL = "total"; + private final long id; private final String type; @@ -63,6 +77,17 @@ public class Task { private final Map headers; + private final Map> resourceStats; + + private final List> resourceTrackingCompletionListeners; + + /** + * Keeps track of the number of active resource tracking threads for this task. It is initialized to 1 to track + * the task's own/self thread. When this value becomes 0, all threads have been marked inactive and the resource + * tracking can be stopped for this task. + */ + private final AtomicInteger numActiveResourceTrackingThreads = new AtomicInteger(1); + /** * The task's start time as a wall clock time since epoch ({@link System#currentTimeMillis()} style). */ @@ -74,7 +99,18 @@ public class Task { private final long startTimeNanos; public Task(long id, String type, String action, String description, TaskId parentTask, Map headers) { - this(id, type, action, description, parentTask, System.currentTimeMillis(), System.nanoTime(), headers); + this( + id, + type, + action, + description, + parentTask, + System.currentTimeMillis(), + System.nanoTime(), + headers, + new ConcurrentHashMap<>(), + new ArrayList<>() + ); } public Task( @@ -85,7 +121,9 @@ public Task( TaskId parentTask, long startTime, long startTimeNanos, - Map headers + Map headers, + ConcurrentHashMap> resourceStats, + List> resourceTrackingCompletionListeners ) { this.id = id; this.type = type; @@ -95,6 +133,8 @@ public Task( this.startTime = startTime; this.startTimeNanos = startTimeNanos; this.headers = headers; + this.resourceStats = resourceStats; + this.resourceTrackingCompletionListeners = resourceTrackingCompletionListeners; } /** @@ -108,19 +148,48 @@ public Task( * generate data? */ public final TaskInfo taskInfo(String localNodeId, boolean detailed) { + return taskInfo(localNodeId, detailed, detailed == false); + } + + /** + * Build a version of the task status you can throw over the wire and back + * with the option to include resource stats or not. + * This method is only used during creating TaskResult to avoid storing resource information into the task index. + * + * @param excludeStats should information exclude resource stats. + * By default, detailed flag is used to control including resource information. + * But inorder to avoid storing resource stats into task index as strict mapping is enforced and breaks when adding this field. + * In the future, task-index-mapping.json can be modified to add resource stats. + */ + private TaskInfo taskInfo(String localNodeId, boolean detailed, boolean excludeStats) { String description = null; Task.Status status = null; + TaskResourceStats resourceStats = null; if (detailed) { description = getDescription(); status = getStatus(); } - return taskInfo(localNodeId, description, status); + if (excludeStats == false) { + resourceStats = new TaskResourceStats(new HashMap() { + { + put(TOTAL, getTotalResourceStats()); + } + }); + } + return taskInfo(localNodeId, description, status, resourceStats); } /** - * Build a proper {@link TaskInfo} for this task. + * Build a {@link TaskInfo} for this task without resource stats. */ protected final TaskInfo taskInfo(String localNodeId, String description, Status status) { + return taskInfo(localNodeId, description, status, null); + } + + /** + * Build a proper {@link TaskInfo} for this task. + */ + protected final TaskInfo taskInfo(String localNodeId, String description, Status status, TaskResourceStats resourceStats) { return new TaskInfo( new TaskId(localNodeId, getId()), getType(), @@ -131,7 +200,8 @@ protected final TaskInfo taskInfo(String localNodeId, String description, Status System.nanoTime() - startTimeNanos, this instanceof CancellableTask, parentTask, - headers + headers, + resourceStats ); } @@ -194,6 +264,115 @@ public Status getStatus() { return null; } + /** + * Returns thread level resource consumption of the task + */ + public Map> getResourceStats() { + return Collections.unmodifiableMap(resourceStats); + } + + /** + * Returns current total resource usage of the task. + * Currently, this method is only called on demand, during get and listing of tasks. + * In the future, these values can be cached as an optimization. + */ + public TaskResourceUsage getTotalResourceStats() { + return new TaskResourceUsage(getTotalResourceUtilization(ResourceStats.CPU), getTotalResourceUtilization(ResourceStats.MEMORY)); + } + + /** + * Returns total resource consumption for a specific task stat. + */ + public long getTotalResourceUtilization(ResourceStats stats) { + long totalResourceConsumption = 0L; + for (List threadResourceInfosList : resourceStats.values()) { + for (ThreadResourceInfo threadResourceInfo : threadResourceInfosList) { + final ResourceUsageInfo.ResourceStatsInfo statsInfo = threadResourceInfo.getResourceUsageInfo().getStatsInfo().get(stats); + if (threadResourceInfo.getStatsType().isOnlyForAnalysis() == false && statsInfo != null) { + totalResourceConsumption += statsInfo.getTotalValue(); + } + } + } + return totalResourceConsumption; + } + + /** + * Adds thread's starting resource consumption information + * @param threadId ID of the thread + * @param statsType stats type + * @param resourceUsageMetrics resource consumption metrics of the thread + * @throws IllegalStateException matching active thread entry was found which is not expected. + */ + public void startThreadResourceTracking(long threadId, ResourceStatsType statsType, ResourceUsageMetric... resourceUsageMetrics) { + final List threadResourceInfoList = resourceStats.computeIfAbsent(threadId, k -> new ArrayList<>()); + // active thread entry should not be present in the list + for (ThreadResourceInfo threadResourceInfo : threadResourceInfoList) { + if (threadResourceInfo.getStatsType() == statsType && threadResourceInfo.isActive()) { + throw new IllegalStateException( + "unexpected active thread resource entry present [" + threadId + "]:[" + threadResourceInfo + "]" + ); + } + } + threadResourceInfoList.add(new ThreadResourceInfo(threadId, statsType, resourceUsageMetrics)); + incrementResourceTrackingThreads(); + } + + /** + * This method is used to update the resource consumption stats so that the data isn't too stale for long-running task. + * If active thread entry is present in the list, the entry is updated. If one is not found, it throws an exception. + * @param threadId ID of the thread + * @param statsType stats type + * @param resourceUsageMetrics resource consumption metrics of the thread + * @throws IllegalStateException if no matching active thread entry was found. + */ + public void updateThreadResourceStats(long threadId, ResourceStatsType statsType, ResourceUsageMetric... resourceUsageMetrics) { + final List threadResourceInfoList = resourceStats.get(threadId); + if (threadResourceInfoList != null) { + for (ThreadResourceInfo threadResourceInfo : threadResourceInfoList) { + // the active entry present in the list is updated + if (threadResourceInfo.getStatsType() == statsType && threadResourceInfo.isActive()) { + threadResourceInfo.recordResourceUsageMetrics(resourceUsageMetrics); + return; + } + } + } + throw new IllegalStateException("cannot update if active thread resource entry is not present"); + } + + /** + * Record the thread's final resource consumption values. + * If active thread entry is present in the list, the entry is updated. If one is not found, it throws an exception. + * @param threadId ID of the thread + * @param statsType stats type + * @param resourceUsageMetrics resource consumption metrics of the thread + * @throws IllegalStateException if no matching active thread entry was found. + */ + public void stopThreadResourceTracking(long threadId, ResourceStatsType statsType, ResourceUsageMetric... resourceUsageMetrics) { + final List threadResourceInfoList = resourceStats.get(threadId); + if (threadResourceInfoList != null) { + for (ThreadResourceInfo threadResourceInfo : threadResourceInfoList) { + if (threadResourceInfo.getStatsType() == statsType && threadResourceInfo.isActive()) { + threadResourceInfo.setActive(false); + threadResourceInfo.recordResourceUsageMetrics(resourceUsageMetrics); + decrementResourceTrackingThreads(); + return; + } + } + } + throw new IllegalStateException("cannot update final values if active thread resource entry is not present"); + } + + /** + * Individual tasks can override this if they want to support task resource tracking. We just need to make sure that + * the ThreadPool on which the task runs on have runnable wrapper similar to + * {@link org.opensearch.common.util.concurrent.OpenSearchExecutors#newAutoQueueFixed} + * + * @return true if resource tracking is supported by the task + */ + public boolean supportsResourceTracking() { + return false; + } + /** * Report of the internal status of a task. These can vary wildly from task * to task because each task is implemented differently but we should try @@ -216,14 +395,73 @@ public String getHeader(String header) { } public TaskResult result(DiscoveryNode node, Exception error) throws IOException { - return new TaskResult(taskInfo(node.getId(), true), error); + return new TaskResult(taskInfo(node.getId(), true, true), error); } public TaskResult result(DiscoveryNode node, ActionResponse response) throws IOException { if (response instanceof ToXContent) { - return new TaskResult(taskInfo(node.getId(), true), (ToXContent) response); + return new TaskResult(taskInfo(node.getId(), true, true), (ToXContent) response); } else { throw new IllegalStateException("response has to implement ToXContent to be able to store the results"); } } + + /** + * Registers a task resource tracking completion listener on this task if resource tracking is still active. + * Returns true on successful subscription, false otherwise. + */ + public boolean addResourceTrackingCompletionListener(NotifyOnceListener listener) { + if (numActiveResourceTrackingThreads.get() > 0) { + resourceTrackingCompletionListeners.add(listener); + return true; + } + + return false; + } + + /** + * Increments the number of active resource tracking threads. + * + * @return the number of active resource tracking threads. + */ + public int incrementResourceTrackingThreads() { + return numActiveResourceTrackingThreads.incrementAndGet(); + } + + /** + * Decrements the number of active resource tracking threads. + * This method is called when threads finish execution, and also when the task is unregistered (to mark the task's + * own thread as complete). When the active thread count becomes zero, the onTaskResourceTrackingCompleted method + * is called exactly once on all registered listeners. + * + * Since a task is unregistered after the message is processed, it implies that the threads responsible to produce + * the response must have started prior to it (i.e. startThreadResourceTracking called before unregister). + * This ensures that the number of active threads doesn't drop to zero pre-maturely. + * + * Rarely, some threads may even start execution after the task is unregistered. As resource stats are piggy-backed + * with the response, any thread usage info captured after the task is unregistered may be irrelevant. + * + * @return the number of active resource tracking threads. + */ + public int decrementResourceTrackingThreads() { + int count = numActiveResourceTrackingThreads.decrementAndGet(); + + if (count == 0) { + List listenerExceptions = new ArrayList<>(); + resourceTrackingCompletionListeners.forEach(listener -> { + try { + listener.onResponse(this); + } catch (Exception e1) { + try { + listener.onFailure(e1); + } catch (Exception e2) { + listenerExceptions.add(e2); + } + } + }); + ExceptionsHelper.maybeThrowRuntimeAndSuppress(listenerExceptions); + } + + return count; + } } diff --git a/server/src/main/java/org/opensearch/tasks/TaskInfo.java b/server/src/main/java/org/opensearch/tasks/TaskInfo.java index 03afa763efd65..36f9ebdfa401e 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskInfo.java +++ b/server/src/main/java/org/opensearch/tasks/TaskInfo.java @@ -33,6 +33,7 @@ package org.opensearch.tasks; import org.opensearch.LegacyESVersion; +import org.opensearch.Version; import org.opensearch.common.ParseField; import org.opensearch.common.Strings; import org.opensearch.common.bytes.BytesReference; @@ -84,6 +85,8 @@ public final class TaskInfo implements Writeable, ToXContentFragment { private final Map headers; + private final TaskResourceStats resourceStats; + public TaskInfo( TaskId taskId, String type, @@ -94,7 +97,8 @@ public TaskInfo( long runningTimeNanos, boolean cancellable, TaskId parentTaskId, - Map headers + Map headers, + TaskResourceStats resourceStats ) { this.taskId = taskId; this.type = type; @@ -106,11 +110,13 @@ public TaskInfo( this.cancellable = cancellable; this.parentTaskId = parentTaskId; this.headers = headers; + this.resourceStats = resourceStats; } /** * Read from a stream. */ + @SuppressWarnings("unchecked") public TaskInfo(StreamInput in) throws IOException { taskId = TaskId.readFromStream(in); type = in.readString(); @@ -126,6 +132,11 @@ public TaskInfo(StreamInput in) throws IOException { } else { headers = Collections.emptyMap(); } + if (in.getVersion().onOrAfter(Version.CURRENT)) { // TODO: Check with Sruti on the version + resourceStats = in.readOptionalWriteable(TaskResourceStats::new); + } else { + resourceStats = null; + } } @Override @@ -142,6 +153,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(LegacyESVersion.V_6_2_0)) { out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString); } + if (out.getVersion().onOrAfter(Version.CURRENT)) { // TODO: Check with Sruti on the version + out.writeOptionalWriteable(resourceStats); + } } public TaskId getTaskId() { @@ -207,6 +221,13 @@ public Map getHeaders() { return headers; } + /** + * Returns the task resource information + */ + public TaskResourceStats getResourceStats() { + return resourceStats; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field("node", taskId.getNodeId()); @@ -233,6 +254,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(attribute.getKey(), attribute.getValue()); } builder.endObject(); + if (resourceStats != null) { + builder.startObject("resource_stats"); + resourceStats.toXContent(builder, params); + builder.endObject(); + } return builder; } @@ -257,9 +283,23 @@ public static TaskInfo fromXContent(XContentParser parser) { // This might happen if we are reading an old version of task info headers = Collections.emptyMap(); } + @SuppressWarnings("unchecked") + TaskResourceStats resourceStats = (TaskResourceStats) a[i++]; RawTaskStatus status = statusBytes == null ? null : new RawTaskStatus(statusBytes); TaskId parentTaskId = parentTaskIdString == null ? TaskId.EMPTY_TASK_ID : new TaskId(parentTaskIdString); - return new TaskInfo(id, type, action, description, status, startTime, runningTimeNanos, cancellable, parentTaskId, headers); + return new TaskInfo( + id, + type, + action, + description, + status, + startTime, + runningTimeNanos, + cancellable, + parentTaskId, + headers, + resourceStats + ); }); static { // Note for the future: this has to be backwards and forwards compatible with all changes to the task storage format @@ -275,6 +315,7 @@ public static TaskInfo fromXContent(XContentParser parser) { PARSER.declareBoolean(constructorArg(), new ParseField("cancellable")); PARSER.declareString(optionalConstructorArg(), new ParseField("parent_task_id")); PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapStrings(), new ParseField("headers")); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> TaskResourceStats.fromXContent(p), new ParseField("resource_stats")); } @Override @@ -298,11 +339,24 @@ public boolean equals(Object obj) { && Objects.equals(parentTaskId, other.parentTaskId) && Objects.equals(cancellable, other.cancellable) && Objects.equals(status, other.status) - && Objects.equals(headers, other.headers); + && Objects.equals(headers, other.headers) + && Objects.equals(resourceStats, other.resourceStats); } @Override public int hashCode() { - return Objects.hash(taskId, type, action, description, startTime, runningTimeNanos, parentTaskId, cancellable, status, headers); + return Objects.hash( + taskId, + type, + action, + description, + startTime, + runningTimeNanos, + parentTaskId, + cancellable, + status, + headers, + resourceStats + ); } } diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 1f6169768f245..b0b28dbd107f8 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -44,6 +44,7 @@ import org.opensearch.OpenSearchTimeoutException; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionResponse; +import org.opensearch.action.NotifyOnceListener; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterStateApplier; import org.opensearch.cluster.node.DiscoveryNode; @@ -89,7 +90,9 @@ public class TaskManager implements ClusterStateApplier { private static final TimeValue WAIT_FOR_COMPLETION_POLL = timeValueMillis(100); - /** Rest headers that are copied to the task */ + /** + * Rest headers that are copied to the task + */ private final List taskHeaders; private final ThreadPool threadPool; @@ -103,6 +106,7 @@ public class TaskManager implements ClusterStateApplier { private final Map banedParents = new ConcurrentHashMap<>(); private TaskResultsService taskResultsService; + private final SetOnce taskResourceTrackingService = new SetOnce<>(); private volatile DiscoveryNodes lastDiscoveryNodes = DiscoveryNodes.EMPTY_NODES; @@ -125,6 +129,10 @@ public void setTaskCancellationService(TaskCancellationService taskCancellationS this.cancellationService.set(taskCancellationService); } + public void setTaskResourceTrackingService(TaskResourceTrackingService taskResourceTrackingService) { + this.taskResourceTrackingService.set(taskResourceTrackingService); + } + /** * Registers a task without parent task */ @@ -150,6 +158,30 @@ public Task register(String type, String action, TaskAwareRequest request) { logger.trace("register {} [{}] [{}] [{}]", task.getId(), type, action, task.getDescription()); } + if (task.supportsResourceTracking()) { + boolean success = task.addResourceTrackingCompletionListener(new NotifyOnceListener() { + @Override + protected void innerOnResponse(Task task) { + // Stop tracking the task once the last thread has been marked inactive. + if (taskResourceTrackingService.get() != null && task.supportsResourceTracking()) { + taskResourceTrackingService.get().stopTracking(task); + } + } + + @Override + protected void innerOnFailure(Exception e) { + ExceptionsHelper.reThrowIfNotNull(e); + } + }); + + if (success == false) { + logger.debug( + "failed to register a completion listener as task resource tracking has already completed [taskId={}]", + task.getId() + ); + } + } + if (task instanceof CancellableTask) { registerCancellableTask(task); } else { @@ -202,6 +234,10 @@ public void cancel(CancellableTask task, String reason, Runnable listener) { */ public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); + + // Decrement the task's self-thread as part of unregistration. + task.decrementResourceTrackingThreads(); + if (task instanceof CancellableTask) { CancellableTaskHolder holder = cancellableTasks.remove(task.getId()); if (holder != null) { @@ -448,6 +484,18 @@ public void waitForTaskCompletion(Task task, long untilInNanos) { throw new OpenSearchTimeoutException("Timed out waiting for completion of [{}]", task); } + /** + * Takes actions when a task is registered and its execution starts + * + * @param task getting executed. + * @return AutoCloseable to free up resources (clean up thread context) when task execution block returns + */ + public ThreadContext.StoredContext taskExecutionStarted(Task task) { + if (taskResourceTrackingService.get() == null) return () -> {}; + + return taskResourceTrackingService.get().startTracking(task); + } + private static class CancellableTaskHolder { private final CancellableTask task; private boolean finished = false; diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceStats.java b/server/src/main/java/org/opensearch/tasks/TaskResourceStats.java new file mode 100644 index 0000000000000..c35e08ebb34ec --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceStats.java @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ToXContentFragment; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * Resource information about a currently running task. + *

+ * Writeable TaskResourceStats objects are used to represent resource + * snapshot information about currently running task. + */ +public class TaskResourceStats implements Writeable, ToXContentFragment { + private final Map resourceUsage; + + public TaskResourceStats(Map resourceUsage) { + this.resourceUsage = Objects.requireNonNull(resourceUsage, "resource usage is required"); + } + + /** + * Read from a stream. + */ + public TaskResourceStats(StreamInput in) throws IOException { + resourceUsage = in.readMap(StreamInput::readString, TaskResourceUsage::readFromStream); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(resourceUsage, StreamOutput::writeString, (stream, stats) -> stats.writeTo(stream)); + } + + public Map getResourceUsageInfo() { + return resourceUsage; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + for (Map.Entry resourceUsageEntry : resourceUsage.entrySet()) { + builder.startObject(resourceUsageEntry.getKey()); + if (resourceUsageEntry.getValue() != null) { + resourceUsageEntry.getValue().toXContent(builder, params); + } + builder.endObject(); + } + return builder; + } + + public static TaskResourceStats fromXContent(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + if (token == null) { + token = parser.nextToken(); + } + if (token == XContentParser.Token.START_OBJECT) { + token = parser.nextToken(); + } + final Map resourceStats = new HashMap<>(); + if (token == XContentParser.Token.FIELD_NAME) { + assert parser.currentToken() == XContentParser.Token.FIELD_NAME : "Expected field name but saw [" + parser.currentToken() + "]"; + do { + // Must point to field name + String fieldName = parser.currentName(); + // And then the value + TaskResourceUsage value = TaskResourceUsage.fromXContent(parser); + resourceStats.put(fieldName, value); + } while (parser.nextToken() == XContentParser.Token.FIELD_NAME); + } + return new TaskResourceStats(resourceStats); + } + + @Override + public String toString() { + return Strings.toString(this, true, true); + } + + // Implements equals and hashcode for testing + @Override + public boolean equals(Object obj) { + if (obj == null || obj.getClass() != TaskResourceStats.class) { + return false; + } + TaskResourceStats other = (TaskResourceStats) obj; + return Objects.equals(resourceUsage, other.resourceUsage); + } + + @Override + public int hashCode() { + return Objects.hash(resourceUsage); + } +} diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java new file mode 100644 index 0000000000000..c3cad117390e4 --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -0,0 +1,248 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import com.sun.management.ThreadMXBean; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ConcurrentCollections; +import org.opensearch.common.util.concurrent.ConcurrentMapLong; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.threadpool.RunnableTaskExecutionListener; +import org.opensearch.threadpool.ThreadPool; + +import java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.opensearch.tasks.ResourceStatsType.WORKER_STATS; + +/** + * Service that helps track resource usage of tasks running on a node. + */ +@SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes") +public class TaskResourceTrackingService implements RunnableTaskExecutionListener { + + private static final Logger logger = LogManager.getLogger(TaskManager.class); + + public static final Setting TASK_RESOURCE_TRACKING_ENABLED = Setting.boolSetting( + "task_resource_tracking.enabled", + true, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + public static final String TASK_ID = "TASK_ID"; + + private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); + + private final ConcurrentMapLong resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); + private final ThreadPool threadPool; + private volatile boolean taskResourceTrackingEnabled; + + @Inject + public TaskResourceTrackingService(Settings settings, ClusterSettings clusterSettings, ThreadPool threadPool) { + this.taskResourceTrackingEnabled = TASK_RESOURCE_TRACKING_ENABLED.get(settings); + this.threadPool = threadPool; + clusterSettings.addSettingsUpdateConsumer(TASK_RESOURCE_TRACKING_ENABLED, this::setTaskResourceTrackingEnabled); + } + + public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) { + this.taskResourceTrackingEnabled = taskResourceTrackingEnabled; + } + + public boolean isTaskResourceTrackingEnabled() { + return taskResourceTrackingEnabled; + } + + public boolean isTaskResourceTrackingSupported() { + return threadMXBean.isThreadAllocatedMemorySupported() && threadMXBean.isThreadAllocatedMemoryEnabled(); + } + + /** + * Executes logic only if task supports resource tracking and resource tracking setting is enabled. + *

+ * 1. Starts tracking the task in map of resourceAwareTasks. + * 2. Adds Task Id in thread context to make sure it's available while task is processed across multiple threads. + * + * @param task for which resources needs to be tracked + * @return Autocloseable stored context to restore ThreadContext to the state before this method changed it. + */ + public ThreadContext.StoredContext startTracking(Task task) { + if (task.supportsResourceTracking() == false + || isTaskResourceTrackingEnabled() == false + || isTaskResourceTrackingSupported() == false) { + return () -> {}; + } + + logger.debug("Starting resource tracking for task: {}", task.getId()); + resourceAwareTasks.put(task.getId(), task); + return addTaskIdToThreadContext(task); + } + + /** + * Stops tracking task registered earlier for tracking. + *

+ * It doesn't have feature enabled check to avoid any issues if setting was disable while the task was in progress. + *

+ * It's also responsible to stop tracking the current thread's resources against this task if not already done. + * This happens when the thread executing the request logic itself calls the unregister method. So in this case unregister + * happens before runnable finishes. + * + * @param task task which has finished and doesn't need resource tracking. + */ + public void stopTracking(Task task) { + logger.debug("Stopping resource tracking for task: {}", task.getId()); + try { + if (isCurrentThreadWorkingOnTask(task)) { + taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); + } + } catch (Exception e) { + logger.warn("Failed while trying to mark the task execution on current thread completed.", e); + assert false; + } finally { + resourceAwareTasks.remove(task.getId()); + } + } + + /** + * Refreshes the resource stats for the tasks provided by looking into which threads are actively working on these + * and how much resources these have consumed till now. + * + * @param tasks for which resource stats needs to be refreshed. + */ + public void refreshResourceStats(Task... tasks) { + if (isTaskResourceTrackingEnabled() == false || isTaskResourceTrackingSupported() == false) { + return; + } + + for (Task task : tasks) { + if (task.supportsResourceTracking() && resourceAwareTasks.containsKey(task.getId())) { + refreshResourceStats(task); + } + } + } + + private void refreshResourceStats(Task resourceAwareTask) { + try { + logger.debug("Refreshing resource stats for Task: {}", resourceAwareTask.getId()); + List threadsWorkingOnTask = getThreadsWorkingOnTask(resourceAwareTask); + threadsWorkingOnTask.forEach( + threadId -> resourceAwareTask.updateThreadResourceStats(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)) + ); + } catch (IllegalStateException e) { + logger.debug("Resource stats already updated."); + } + + } + + /** + * Called when a thread starts working on a task's runnable. + * + * @param taskId of the task for which runnable is starting + * @param threadId of the thread which will be executing the runnable and we need to check resource usage for this + * thread + */ + @Override + public void taskExecutionStartedOnThread(long taskId, long threadId) { + try { + final Task task = resourceAwareTasks.get(taskId); + if (task != null) { + logger.debug("Task execution started on thread. Task: {}, Thread: {}", taskId, threadId); + task.startThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + } + } catch (Exception e) { + logger.warn(new ParameterizedMessage("Failed to mark thread execution started for task: [{}]", taskId), e); + assert false; + } + + } + + /** + * Called when a thread finishes working on a task's runnable. + * + * @param taskId of the task for which runnable is complete + * @param threadId of the thread which executed the runnable and we need to check resource usage for this thread + */ + @Override + public void taskExecutionFinishedOnThread(long taskId, long threadId) { + try { + final Task task = resourceAwareTasks.get(taskId); + if (task != null) { + logger.debug("Task execution finished on thread. Task: {}, Thread: {}", taskId, threadId); + task.stopThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + } + } catch (Exception e) { + logger.warn(new ParameterizedMessage("Failed to mark thread execution finished for task: [{}]", taskId), e); + assert false; + } + } + + public Map getResourceAwareTasks() { + return Collections.unmodifiableMap(resourceAwareTasks); + } + + private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { + ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric( + ResourceStats.MEMORY, + threadMXBean.getThreadAllocatedBytes(threadId) + ); + ResourceUsageMetric currentCPUUsage = new ResourceUsageMetric(ResourceStats.CPU, threadMXBean.getThreadCpuTime(threadId)); + return new ResourceUsageMetric[] { currentMemoryUsage, currentCPUUsage }; + } + + private boolean isCurrentThreadWorkingOnTask(Task task) { + long threadId = Thread.currentThread().getId(); + List threadResourceInfos = task.getResourceStats().getOrDefault(threadId, Collections.emptyList()); + + for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { + if (threadResourceInfo.isActive()) { + return true; + } + } + return false; + } + + private List getThreadsWorkingOnTask(Task task) { + List activeThreads = new ArrayList<>(); + for (List threadResourceInfos : task.getResourceStats().values()) { + for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { + if (threadResourceInfo.isActive()) { + activeThreads.add(threadResourceInfo.getThreadId()); + } + } + } + return activeThreads; + } + + /** + * Adds Task Id in the ThreadContext. + *

+ * Stashes the existing ThreadContext and preserves all the existing ThreadContext's data in the new ThreadContext + * as well. + * + * @param task for which Task Id needs to be added in ThreadContext. + * @return StoredContext reference to restore the ThreadContext from which we created a new one. + * Caller can call context.restore() to get the existing ThreadContext back. + */ + private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) { + ThreadContext threadContext = threadPool.getThreadContext(); + ThreadContext.StoredContext storedContext = threadContext.newStoredContext(true, Collections.singletonList(TASK_ID)); + threadContext.putTransient(TASK_ID, task.getId()); + return storedContext; + } + +} diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceUsage.java b/server/src/main/java/org/opensearch/tasks/TaskResourceUsage.java new file mode 100644 index 0000000000000..6af3de2b78c06 --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceUsage.java @@ -0,0 +1,105 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.common.ParseField; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ConstructingObjectParser; +import org.opensearch.common.xcontent.ToXContentFragment; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.opensearch.common.xcontent.ConstructingObjectParser.constructorArg; + +/** + * Task resource usage information + *

+ * Writeable TaskResourceUsage objects are used to represent resource usage + * information of running tasks. + */ +public class TaskResourceUsage implements Writeable, ToXContentFragment { + private static final ParseField CPU_TIME_IN_NANOS = new ParseField("cpu_time_in_nanos"); + private static final ParseField MEMORY_IN_BYTES = new ParseField("memory_in_bytes"); + + private final long cpuTimeInNanos; + private final long memoryInBytes; + + public TaskResourceUsage(long cpuTimeInNanos, long memoryInBytes) { + this.cpuTimeInNanos = cpuTimeInNanos; + this.memoryInBytes = memoryInBytes; + } + + /** + * Read from a stream. + */ + public static TaskResourceUsage readFromStream(StreamInput in) throws IOException { + return new TaskResourceUsage(in.readVLong(), in.readVLong()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(cpuTimeInNanos); + out.writeVLong(memoryInBytes); + } + + public long getCpuTimeInNanos() { + return cpuTimeInNanos; + } + + public long getMemoryInBytes() { + return memoryInBytes; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(CPU_TIME_IN_NANOS.getPreferredName(), cpuTimeInNanos); + builder.field(MEMORY_IN_BYTES.getPreferredName(), memoryInBytes); + return builder; + } + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "task_resource_usage", + a -> new TaskResourceUsage((Long) a[0], (Long) a[1]) + ); + + static { + PARSER.declareLong(constructorArg(), CPU_TIME_IN_NANOS); + PARSER.declareLong(constructorArg(), MEMORY_IN_BYTES); + } + + public static TaskResourceUsage fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + @Override + public String toString() { + return Strings.toString(this, true, true); + } + + // Implements equals and hashcode for testing + @Override + public boolean equals(Object obj) { + if (obj == null || obj.getClass() != TaskResourceUsage.class) { + return false; + } + TaskResourceUsage other = (TaskResourceUsage) obj; + return Objects.equals(cpuTimeInNanos, other.cpuTimeInNanos) && Objects.equals(memoryInBytes, other.memoryInBytes); + } + + @Override + public int hashCode() { + return Objects.hash(cpuTimeInNanos, memoryInBytes); + } +} diff --git a/server/src/main/java/org/opensearch/tasks/ThreadResourceInfo.java b/server/src/main/java/org/opensearch/tasks/ThreadResourceInfo.java new file mode 100644 index 0000000000000..9ee683e3928f6 --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/ThreadResourceInfo.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +/** + * Resource consumption information about a particular execution of thread. + *

+ * It captures the resource usage information about a particular execution of thread + * for a specific stats type like worker_stats or response_stats etc., + */ +public class ThreadResourceInfo { + private final long threadId; + private volatile boolean isActive = true; + private final ResourceStatsType statsType; + private final ResourceUsageInfo resourceUsageInfo; + + public ThreadResourceInfo(long threadId, ResourceStatsType statsType, ResourceUsageMetric... resourceUsageMetrics) { + this.threadId = threadId; + this.statsType = statsType; + this.resourceUsageInfo = new ResourceUsageInfo(resourceUsageMetrics); + } + + /** + * Updates thread's resource consumption information. + */ + public void recordResourceUsageMetrics(ResourceUsageMetric... resourceUsageMetrics) { + resourceUsageInfo.recordResourceUsageMetrics(resourceUsageMetrics); + } + + public void setActive(boolean isActive) { + this.isActive = isActive; + } + + public boolean isActive() { + return isActive; + } + + public ResourceStatsType getStatsType() { + return statsType; + } + + public long getThreadId() { + return threadId; + } + + public ResourceUsageInfo getResourceUsageInfo() { + return resourceUsageInfo; + } + + @Override + public String toString() { + return resourceUsageInfo + ", stats_type=" + statsType + ", is_active=" + isActive + ", threadId=" + threadId; + } +} diff --git a/server/src/main/java/org/opensearch/threadpool/AutoQueueAdjustingExecutorBuilder.java b/server/src/main/java/org/opensearch/threadpool/AutoQueueAdjustingExecutorBuilder.java index 2bac5eba9fc28..55b92c5d8bfcb 100644 --- a/server/src/main/java/org/opensearch/threadpool/AutoQueueAdjustingExecutorBuilder.java +++ b/server/src/main/java/org/opensearch/threadpool/AutoQueueAdjustingExecutorBuilder.java @@ -48,6 +48,7 @@ import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicReference; /** * A builder for executors that automatically adjust the queue length as needed, depending on @@ -61,6 +62,7 @@ public final class AutoQueueAdjustingExecutorBuilder extends ExecutorBuilder maxQueueSizeSetting; private final Setting targetedResponseTimeSetting; private final Setting frameSizeSetting; + private final AtomicReference runnableTaskListener; AutoQueueAdjustingExecutorBuilder( final Settings settings, @@ -70,6 +72,19 @@ public final class AutoQueueAdjustingExecutorBuilder extends ExecutorBuilder runnableTaskListener ) { super(name); final String prefix = "thread_pool." + name; @@ -184,6 +199,7 @@ public Iterator> settings() { Setting.Property.Deprecated, Setting.Property.Deprecated ); + this.runnableTaskListener = runnableTaskListener; } @Override @@ -230,7 +246,8 @@ ThreadPool.ExecutorHolder build(final AutoExecutorSettings settings, final Threa frameSize, targetedResponseTime, threadFactory, - threadContext + threadContext, + runnableTaskListener ); // TODO: in a subsequent change we hope to extend ThreadPool.Info to be more specific for the thread pool type final ThreadPool.Info info = new ThreadPool.Info( diff --git a/server/src/main/java/org/opensearch/threadpool/RunnableTaskExecutionListener.java b/server/src/main/java/org/opensearch/threadpool/RunnableTaskExecutionListener.java new file mode 100644 index 0000000000000..03cd66f80d044 --- /dev/null +++ b/server/src/main/java/org/opensearch/threadpool/RunnableTaskExecutionListener.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.threadpool; + +/** + * Listener for events when a runnable execution starts or finishes on a thread and is aware of the task for which the + * runnable is associated to. + */ +public interface RunnableTaskExecutionListener { + + /** + * Sends an update when ever a task's execution start on a thread + * + * @param taskId of task which has started + * @param threadId of thread which is executing the task + */ + void taskExecutionStartedOnThread(long taskId, long threadId); + + /** + * + * Sends an update when task execution finishes on a thread + * + * @param taskId of task which has finished + * @param threadId of thread which executed the task + */ + void taskExecutionFinishedOnThread(long taskId, long threadId); +} diff --git a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java new file mode 100644 index 0000000000000..183b9b2f4cf9a --- /dev/null +++ b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.threadpool; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.common.util.concurrent.AbstractRunnable; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.util.concurrent.WrappedRunnable; +import org.opensearch.tasks.TaskManager; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; + +import static java.lang.Thread.currentThread; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; + +/** + * Responsible for wrapping the original task's runnable and sending updates on when it starts and finishes to + * entities listening to the events. + * + * It's able to associate runnable with a task with the help of task Id available in thread context. + */ +public class TaskAwareRunnable extends AbstractRunnable implements WrappedRunnable { + + private static final Logger logger = LogManager.getLogger(TaskManager.class); + + private final Runnable original; + private final ThreadContext threadContext; + private final AtomicReference runnableTaskListener; + + public TaskAwareRunnable( + final ThreadContext threadContext, + final Runnable original, + final AtomicReference runnableTaskListener + ) { + this.original = original; + this.threadContext = threadContext; + this.runnableTaskListener = runnableTaskListener; + } + + @Override + public void onFailure(Exception e) { + ExceptionsHelper.reThrowIfNotNull(e); + } + + @Override + public boolean isForceExecution() { + return original instanceof AbstractRunnable && ((AbstractRunnable) original).isForceExecution(); + } + + @Override + public void onRejection(final Exception e) { + if (original instanceof AbstractRunnable) { + ((AbstractRunnable) original).onRejection(e); + } else { + ExceptionsHelper.reThrowIfNotNull(e); + } + } + + @Override + protected void doRun() throws Exception { + assert runnableTaskListener.get() != null : "Listener should be attached"; + Long taskId = threadContext.getTransient(TASK_ID); + if (Objects.nonNull(taskId)) { + runnableTaskListener.get().taskExecutionStartedOnThread(taskId, currentThread().getId()); + } else { + logger.debug("Task Id not available in thread context. Skipping update. Thread Info: {}", Thread.currentThread()); + } + try { + original.run(); + } finally { + if (Objects.nonNull(taskId)) { + runnableTaskListener.get().taskExecutionFinishedOnThread(taskId, currentThread().getId()); + } + } + } + + @Override + public Runnable unwrap() { + return original; + } +} diff --git a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java index 7136932318317..5545cdbaa8720 100644 --- a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java @@ -69,6 +69,7 @@ import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import static java.util.Collections.unmodifiableMap; @@ -185,6 +186,14 @@ public Collection builders() { ); public ThreadPool(final Settings settings, final ExecutorBuilder... customBuilders) { + this(settings, null, customBuilders); + } + + public ThreadPool( + final Settings settings, + final AtomicReference runnableTaskListener, + final ExecutorBuilder... customBuilders + ) { assert Node.NODE_NAME_SETTING.exists(settings); final Map builders = new HashMap<>(); @@ -198,11 +207,20 @@ public ThreadPool(final Settings settings, final ExecutorBuilder... customBui builders.put(Names.ANALYZE, new FixedExecutorBuilder(settings, Names.ANALYZE, 1, 16)); builders.put( Names.SEARCH, - new AutoQueueAdjustingExecutorBuilder(settings, Names.SEARCH, searchThreadPoolSize(allocatedProcessors), 1000, 1000, 1000, 2000) + new AutoQueueAdjustingExecutorBuilder( + settings, + Names.SEARCH, + searchThreadPoolSize(allocatedProcessors), + 1000, + 1000, + 1000, + 2000, + runnableTaskListener + ) ); builders.put( Names.SEARCH_THROTTLED, - new AutoQueueAdjustingExecutorBuilder(settings, Names.SEARCH_THROTTLED, 1, 100, 100, 100, 200) + new AutoQueueAdjustingExecutorBuilder(settings, Names.SEARCH_THROTTLED, 1, 100, 100, 100, 200, runnableTaskListener) ); builders.put(Names.MANAGEMENT, new ScalingExecutorBuilder(Names.MANAGEMENT, 1, 5, TimeValue.timeValueMinutes(5))); // no queue as this means clients will need to handle rejections on listener queue even if the operation succeeded diff --git a/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java b/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java index 017aac7ae7e8d..4f37b4bfb2860 100644 --- a/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java +++ b/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java @@ -36,6 +36,7 @@ import org.opensearch.common.io.stream.Writeable; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskManager; @@ -80,6 +81,8 @@ public Request newRequest(StreamInput in) throws IOException { public void processMessageReceived(Request request, TransportChannel channel) throws Exception { final Task task = taskManager.register(channel.getChannelType(), action, request); + ThreadContext.StoredContext contextToRestore = taskManager.taskExecutionStarted(task); + Releasable unregisterTask = () -> taskManager.unregister(task); try { if (channel instanceof TcpTransportChannel && task instanceof CancellableTask) { @@ -92,6 +95,7 @@ public void processMessageReceived(Request request, TransportChannel channel) th unregisterTask = null; } finally { Releasables.close(unregisterTask); + contextToRestore.restore(); } } diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java index 7756eb12bb3f4..9bd44185baf24 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java @@ -75,6 +75,9 @@ public synchronized void onTaskUnregistered(Task task) { @Override public void waitForTaskCompletion(Task task) {} + @Override + public void taskExecutionStarted(Task task, Boolean closeableInvoked) {} + public synchronized List> getEvents() { return Collections.unmodifiableList(new ArrayList<>(events)); } diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java new file mode 100644 index 0000000000000..f21519a5f4b40 --- /dev/null +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -0,0 +1,653 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.admin.cluster.node.tasks; + +import com.sun.management.ThreadMXBean; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.ActionListener; +import org.opensearch.action.NotifyOnceListener; +import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; +import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksRequest; +import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksResponse; +import org.opensearch.action.support.ActionTestUtils; +import org.opensearch.action.support.nodes.BaseNodeRequest; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.AbstractRunnable; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancelledException; +import org.opensearch.tasks.TaskId; +import org.opensearch.tasks.TaskInfo; +import org.opensearch.test.tasks.MockTaskManager; +import org.opensearch.test.tasks.MockTaskManagerListener; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; + +@SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes") +public class ResourceAwareTasksTests extends TaskManagerTestCase { + + private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); + + public static class ResourceAwareNodeRequest extends BaseNodeRequest { + protected String requestName; + + public ResourceAwareNodeRequest() { + super(); + } + + public ResourceAwareNodeRequest(StreamInput in) throws IOException { + super(in); + requestName = in.readString(); + } + + public ResourceAwareNodeRequest(NodesRequest request) { + requestName = request.requestName; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(requestName); + } + + @Override + public String getDescription() { + return "ResourceAwareNodeRequest[" + requestName + "]"; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return false; + } + + @Override + public boolean supportsResourceTracking() { + return true; + } + }; + } + } + + public static class NodesRequest extends BaseNodesRequest { + private final String requestName; + + private NodesRequest(StreamInput in) throws IOException { + super(in); + requestName = in.readString(); + } + + public NodesRequest(String requestName, String... nodesIds) { + super(nodesIds); + this.requestName = requestName; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(requestName); + } + + @Override + public String getDescription() { + return "NodesRequest[" + requestName + "]"; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return true; + } + }; + } + } + + /** + * Simulates a task which executes work on search executor. + */ + class ResourceAwareNodesAction extends AbstractTestNodesAction { + private final TaskTestContext taskTestContext; + private final boolean blockForCancellation; + + ResourceAwareNodesAction( + String actionName, + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + boolean shouldBlock, + TaskTestContext taskTestContext + ) { + super(actionName, threadPool, clusterService, transportService, NodesRequest::new, ResourceAwareNodeRequest::new); + this.taskTestContext = taskTestContext; + this.blockForCancellation = shouldBlock; + } + + @Override + protected ResourceAwareNodeRequest newNodeRequest(NodesRequest request) { + return new ResourceAwareNodeRequest(request); + } + + @Override + protected NodeResponse nodeOperation(ResourceAwareNodeRequest request, Task task) { + assert task.supportsResourceTracking(); + + AtomicLong threadId = new AtomicLong(); + Future result = threadPool.executor(ThreadPool.Names.SEARCH).submit(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + ExceptionsHelper.reThrowIfNotNull(e); + } + + @Override + @SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes") + protected void doRun() { + taskTestContext.memoryConsumptionWhenExecutionStarts = threadMXBean.getThreadAllocatedBytes( + Thread.currentThread().getId() + ); + threadId.set(Thread.currentThread().getId()); + + // operationStartValidator will be called just before the task execution. + if (taskTestContext.operationStartValidator != null) { + taskTestContext.operationStartValidator.accept(task, threadId.get()); + } + + // operationFinishedValidator will be called just after all task threads are marked inactive and + // the task is unregistered. + if (taskTestContext.operationFinishedValidator != null) { + boolean success = task.addResourceTrackingCompletionListener(new NotifyOnceListener() { + @Override + protected void innerOnResponse(Task task) { + taskTestContext.operationFinishedValidator.accept(task, threadId.get()); + } + + @Override + protected void innerOnFailure(Exception e) { + ExceptionsHelper.reThrowIfNotNull(e); + } + }); + + if (success == false) { + fail("failed to register a completion listener as task resource tracking has already completed"); + } + } + + Object[] allocation1 = new Object[1000000]; // 4MB + + if (blockForCancellation) { + // Simulate a job that takes forever to finish + // Using periodic checks method to identify that the task was cancelled + try { + boolean taskCancelled = waitUntil(((CancellableTask) task)::isCancelled); + if (taskCancelled) { + throw new TaskCancelledException("Task Cancelled"); + } else { + fail("It should have thrown an exception"); + } + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + } + + } + + Object[] allocation2 = new Object[1000000]; // 4MB + } + }); + + try { + result.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e.getCause()); + } + + return new NodeResponse(clusterService.localNode()); + } + + @Override + protected NodeResponse nodeOperation(ResourceAwareNodeRequest request) { + throw new UnsupportedOperationException("the task parameter is required"); + } + } + + private TaskTestContext startResourceAwareNodesAction( + TestNode node, + boolean blockForCancellation, + TaskTestContext taskTestContext, + ActionListener listener + ) { + NodesRequest request = new NodesRequest("Test Request", node.getNodeId()); + + taskTestContext.requestCompleteLatch = new CountDownLatch(1); + + ResourceAwareNodesAction action = new ResourceAwareNodesAction( + "internal:resourceAction", + threadPool, + node.clusterService, + node.transportService, + blockForCancellation, + taskTestContext + ); + taskTestContext.mainTask = action.execute(request, listener); + return taskTestContext; + } + + private static class TaskTestContext { + private Task mainTask; + private CountDownLatch requestCompleteLatch; + private BiConsumer operationStartValidator; + private BiConsumer operationFinishedValidator; + private long memoryConsumptionWhenExecutionStarts; + } + + public void testBasicTaskResourceTracking() throws Exception { + setup(true, false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = (task, threadId) -> { + // One thread is currently working on task but not finished + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertTrue(task.getResourceStats().get(threadId).get(0).isActive()); + assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos()); + assertEquals(0, task.getTotalResourceStats().getMemoryInBytes()); + }; + + taskTestContext.operationFinishedValidator = (task, threadId) -> { + // Thread has finished working on the task's runnable + assertEquals(0, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); + + long expectedArrayAllocationOverhead = 2 * 4000000; // Task's memory overhead due to array allocations + long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes(); + + assertMemoryUsageWithinLimits( + actualTaskMemoryOverhead - taskTestContext.memoryConsumptionWhenExecutionStarts, + expectedArrayAllocationOverhead + ); + assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); + }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(responseReference.get(), throwableReference.get()); + } + + public void testTaskResourceTrackingDuringTaskCancellation() throws Exception { + setup(true, false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = (task, threadId) -> { + // One thread is currently working on task but not finished + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertTrue(task.getResourceStats().get(threadId).get(0).isActive()); + assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos()); + assertEquals(0, task.getTotalResourceStats().getMemoryInBytes()); + }; + + taskTestContext.operationFinishedValidator = (task, threadId) -> { + // Thread has finished working on the task's runnable + assertEquals(0, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); + + // allocations are completed before the task is cancelled + long expectedArrayAllocationOverhead = 4000000; // Task's memory overhead due to array allocations + long taskCancellationOverhead = 30000; // Task cancellation overhead ~ 30Kb + long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes(); + + long expectedOverhead = expectedArrayAllocationOverhead + taskCancellationOverhead; + assertMemoryUsageWithinLimits( + actualTaskMemoryOverhead - taskTestContext.memoryConsumptionWhenExecutionStarts, + expectedOverhead + ); + assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); + }; + + startResourceAwareNodesAction(testNodes[0], true, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Cancel main task + CancelTasksRequest request = new CancelTasksRequest(); + request.setReason("Cancelling request to verify Task resource tracking behaviour"); + request.setTaskId(new TaskId(testNodes[0].getNodeId(), taskTestContext.mainTask.getId())); + ActionTestUtils.executeBlocking(testNodes[0].transportCancelTasksAction, request); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertEquals(0, resourceTasks.size()); + assertNull(throwableReference.get()); + assertNotNull(responseReference.get()); + assertEquals(1, responseReference.get().failureCount()); + assertEquals(TaskCancelledException.class, findActualException(responseReference.get().failures().get(0)).getClass()); + } + + public void testTaskResourceTrackingDisabled() throws Exception { + setup(false, false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = (task, threadId) -> { assertEquals(0, resourceTasks.size()); }; + + taskTestContext.operationFinishedValidator = (task, threadId) -> { assertEquals(0, resourceTasks.size()); }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(responseReference.get(), throwableReference.get()); + } + + public void testTaskResourceTrackingDisabledWhileTaskInProgress() throws Exception { + setup(true, false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = (task, threadId) -> { + // One thread is currently working on task but not finished + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertTrue(task.getResourceStats().get(threadId).get(0).isActive()); + assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos()); + assertEquals(0, task.getTotalResourceStats().getMemoryInBytes()); + + testNodes[0].taskResourceTrackingService.setTaskResourceTrackingEnabled(false); + }; + + taskTestContext.operationFinishedValidator = (task, threadId) -> { + // Thread has finished working on the task's runnable + assertEquals(0, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); + + long expectedArrayAllocationOverhead = 2 * 4000000; // Task's memory overhead due to array allocations + long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes(); + + assertMemoryUsageWithinLimits( + actualTaskMemoryOverhead - taskTestContext.memoryConsumptionWhenExecutionStarts, + expectedArrayAllocationOverhead + ); + assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); + }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(responseReference.get(), throwableReference.get()); + } + + public void testTaskResourceTrackingEnabledWhileTaskInProgress() throws Exception { + setup(false, false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = (task, threadId) -> { + assertEquals(0, resourceTasks.size()); + + testNodes[0].taskResourceTrackingService.setTaskResourceTrackingEnabled(true); + }; + + taskTestContext.operationFinishedValidator = (task, threadId) -> { assertEquals(0, resourceTasks.size()); }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(responseReference.get(), throwableReference.get()); + } + + public void testOnDemandRefreshWhileFetchingTasks() throws InterruptedException { + setup(true, false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = (task, threadId) -> { + ListTasksResponse listTasksResponse = ActionTestUtils.executeBlocking( + testNodes[0].transportListTasksAction, + new ListTasksRequest().setActions("internal:resourceAction*").setDetailed(true) + ); + + TaskInfo taskInfo = listTasksResponse.getTasks().get(1); + + assertNotNull(taskInfo.getResourceStats()); + assertNotNull(taskInfo.getResourceStats().getResourceUsageInfo()); + assertTrue(taskInfo.getResourceStats().getResourceUsageInfo().get("total").getCpuTimeInNanos() > 0); + assertTrue(taskInfo.getResourceStats().getResourceUsageInfo().get("total").getMemoryInBytes() > 0); + }; + + taskTestContext.operationFinishedValidator = (task, threadId) -> { assertEquals(0, resourceTasks.size()); }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(responseReference.get(), throwableReference.get()); + } + + public void testTaskIdPersistsInThreadContext() throws InterruptedException { + setup(true, true); + + final List taskIdsAddedToThreadContext = new ArrayList<>(); + final List taskIdsRemovedFromThreadContext = new ArrayList<>(); + AtomicLong actualTaskIdInThreadContext = new AtomicLong(-1); + AtomicLong expectedTaskIdInThreadContext = new AtomicLong(-2); + + ((MockTaskManager) testNodes[0].transportService.getTaskManager()).addListener(new MockTaskManagerListener() { + @Override + public void waitForTaskCompletion(Task task) {} + + @Override + public void taskExecutionStarted(Task task, Boolean closeableInvoked) { + if (closeableInvoked) { + taskIdsRemovedFromThreadContext.add(task.getId()); + } else { + taskIdsAddedToThreadContext.add(task.getId()); + } + } + + @Override + public void onTaskRegistered(Task task) {} + + @Override + public void onTaskUnregistered(Task task) { + if (task.getAction().equals("internal:resourceAction[n]")) { + expectedTaskIdInThreadContext.set(task.getId()); + actualTaskIdInThreadContext.set(threadPool.getThreadContext().getTransient(TASK_ID)); + } + } + }); + + TaskTestContext taskTestContext = new TaskTestContext(); + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + taskTestContext.requestCompleteLatch.await(); + + assertEquals(expectedTaskIdInThreadContext.get(), actualTaskIdInThreadContext.get()); + assertThat(taskIdsAddedToThreadContext, containsInAnyOrder(taskIdsRemovedFromThreadContext.toArray())); + } + + private void setup(boolean resourceTrackingEnabled, boolean useMockTaskManager) { + Settings settings = Settings.builder() + .put("task_resource_tracking.enabled", resourceTrackingEnabled) + .put(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.getKey(), useMockTaskManager) + .build(); + setupTestNodes(settings); + connectNodes(testNodes[0]); + + runnableTaskListener.set(testNodes[0].taskResourceTrackingService); + } + + private Throwable findActualException(Exception e) { + Throwable throwable = e.getCause(); + while (throwable.getCause() != null) { + throwable = throwable.getCause(); + } + return throwable; + } + + private void assertTasksRequestFinishedSuccessfully(NodesResponse nodesResponse, Throwable throwable) { + assertNull(throwable); + assertNotNull(nodesResponse); + assertEquals(0, nodesResponse.failureCount()); + } + + private void assertMemoryUsageWithinLimits(long actual, long expected) { + // 5% buffer up to 200 KB to account for classloading overhead. + long maxOverhead = Math.min(200000, expected * 5 / 100); + assertThat(actual, lessThanOrEqualTo(expected + maxOverhead)); + } +} diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java index c8411b31e0709..51fc5d80f2de3 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java @@ -59,8 +59,10 @@ import org.opensearch.indices.breaker.NoneCircuitBreakerService; import org.opensearch.tasks.TaskCancellationService; import org.opensearch.tasks.TaskManager; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.tasks.MockTaskManager; +import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -74,6 +76,7 @@ import java.util.List; import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import static java.util.Collections.emptyMap; @@ -89,10 +92,12 @@ public abstract class TaskManagerTestCase extends OpenSearchTestCase { protected ThreadPool threadPool; protected TestNode[] testNodes; protected int nodesCount; + protected AtomicReference runnableTaskListener; @Before public void setupThreadPool() { - threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName()); + runnableTaskListener = new AtomicReference<>(); + threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), runnableTaskListener); } public void setupTestNodes(Settings settings) { @@ -225,14 +230,22 @@ protected TaskManager createTaskManager(Settings settings, ThreadPool threadPool transportService.start(); clusterService = createClusterService(threadPool, discoveryNode.get()); clusterService.addStateApplier(transportService.getTaskManager()); + taskResourceTrackingService = new TaskResourceTrackingService(settings, clusterService.getClusterSettings(), threadPool); + transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService); ActionFilters actionFilters = new ActionFilters(emptySet()); - transportListTasksAction = new TransportListTasksAction(clusterService, transportService, actionFilters); + transportListTasksAction = new TransportListTasksAction( + clusterService, + transportService, + actionFilters, + taskResourceTrackingService + ); transportCancelTasksAction = new TransportCancelTasksAction(clusterService, transportService, actionFilters); transportService.acceptIncomingRequests(); } public final ClusterService clusterService; public final TransportService transportService; + public final TaskResourceTrackingService taskResourceTrackingService; private final SetOnce discoveryNode = new SetOnce<>(); public final TransportListTasksAction transportListTasksAction; public final TransportCancelTasksAction transportCancelTasksAction; diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskTests.java index a2f8f0a5f7a44..a2588aa3cb400 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskTests.java @@ -31,16 +31,23 @@ package org.opensearch.action.admin.cluster.node.tasks; +import org.opensearch.action.search.SearchAction; import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskInfo; +import org.opensearch.tasks.ResourceUsageMetric; +import org.opensearch.tasks.ResourceStats; +import org.opensearch.tasks.ResourceStatsType; import org.opensearch.test.OpenSearchTestCase; import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.Map; +import static org.opensearch.tasks.TaskInfoTests.randomResourceStats; + public class TaskTests extends OpenSearchTestCase { public void testTaskInfoToString() { @@ -59,7 +66,8 @@ public void testTaskInfoToString() { runningTime, cancellable, TaskId.EMPTY_TASK_ID, - Collections.singletonMap("foo", "bar") + Collections.singletonMap("foo", "bar"), + randomResourceStats(randomBoolean()) ); String taskInfoString = taskInfo.toString(); Map map = XContentHelper.convertToMap(new BytesArray(taskInfoString.getBytes(StandardCharsets.UTF_8)), true).v2(); @@ -73,4 +81,68 @@ public void testTaskInfoToString() { assertEquals(map.get("headers"), Collections.singletonMap("foo", "bar")); } + public void testTaskResourceStats() { + final Task task = new Task( + randomLong(), + "transport", + SearchAction.NAME, + "description", + new TaskId(randomLong() + ":" + randomLong()), + Collections.emptyMap() + ); + + long totalMemory = 0L; + long totalCPU = 0L; + + // reporting resource consumption events and checking total consumption values + for (int i = 0; i < randomInt(10); i++) { + long initial_memory = randomLongBetween(1, 100); + long initial_cpu = randomLongBetween(1, 100); + + ResourceUsageMetric[] initialTaskResourceMetrics = new ResourceUsageMetric[] { + new ResourceUsageMetric(ResourceStats.MEMORY, initial_memory), + new ResourceUsageMetric(ResourceStats.CPU, initial_cpu) }; + task.startThreadResourceTracking(i, ResourceStatsType.WORKER_STATS, initialTaskResourceMetrics); + + long memory = initial_memory + randomLongBetween(1, 10000); + long cpu = initial_cpu + randomLongBetween(1, 10000); + + totalMemory += memory - initial_memory; + totalCPU += cpu - initial_cpu; + + ResourceUsageMetric[] taskResourceMetrics = new ResourceUsageMetric[] { + new ResourceUsageMetric(ResourceStats.MEMORY, memory), + new ResourceUsageMetric(ResourceStats.CPU, cpu) }; + task.updateThreadResourceStats(i, ResourceStatsType.WORKER_STATS, taskResourceMetrics); + task.stopThreadResourceTracking(i, ResourceStatsType.WORKER_STATS); + } + assertEquals(task.getTotalResourceStats().getMemoryInBytes(), totalMemory); + assertEquals(task.getTotalResourceStats().getCpuTimeInNanos(), totalCPU); + + // updating should throw an IllegalStateException when active entry is not present. + try { + task.updateThreadResourceStats(randomInt(), ResourceStatsType.WORKER_STATS); + fail("update should not be successful as active entry is not present!"); + } catch (IllegalStateException e) { + // pass + } + + // re-adding a thread entry that is already present, should throw an exception + int threadId = randomInt(); + task.startThreadResourceTracking(threadId, ResourceStatsType.WORKER_STATS, new ResourceUsageMetric(ResourceStats.MEMORY, 100)); + try { + task.startThreadResourceTracking(threadId, ResourceStatsType.WORKER_STATS); + fail("add/start should not be successful as active entry is already present!"); + } catch (IllegalStateException e) { + // pass + } + + // existing active entry is present only for memory, update cannot be called with cpu values. + try { + task.updateThreadResourceStats(threadId, ResourceStatsType.WORKER_STATS, new ResourceUsageMetric(ResourceStats.CPU, 200)); + fail("update should not be successful as entry for CPU is not present!"); + } catch (IllegalStateException e) { + // pass + } + } } diff --git a/server/src/test/java/org/opensearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/opensearch/action/bulk/TransportBulkActionIngestTests.java index d98832e73cc88..3b3f0960023d6 100644 --- a/server/src/test/java/org/opensearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/opensearch/action/bulk/TransportBulkActionIngestTests.java @@ -91,6 +91,7 @@ import static java.util.Collections.emptyMap; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.sameInstance; +import static org.mockito.Answers.RETURNS_MOCKS; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyString; @@ -224,7 +225,7 @@ public void setupAction() { remoteResponseHandler = ArgumentCaptor.forClass(TransportResponseHandler.class); // setup services that will be called by action - transportService = mock(TransportService.class); + transportService = mock(TransportService.class, RETURNS_MOCKS); clusterService = mock(ClusterService.class); localIngest = true; // setup nodes for local and remote diff --git a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java index 9c70accaca3e4..64286e47b4966 100644 --- a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java +++ b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java @@ -48,6 +48,7 @@ import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.sameInstance; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; public class ThreadContextTests extends OpenSearchTestCase { @@ -154,6 +155,15 @@ public void testNewContextWithClearedTransients() { assertEquals(1, threadContext.getResponseHeaders().get("baz").size()); } + public void testStashContextWithPreservedTransients() { + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + threadContext.putTransient("foo", "bar"); + threadContext.putTransient(TASK_ID, 1); + threadContext.stashContext(); + assertNull(threadContext.getTransient("foo")); + assertEquals(1, (int) threadContext.getTransient(TASK_ID)); + } + public void testStashWithOrigin() { final String origin = randomAlphaOfLengthBetween(4, 16); final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index 9496054e80ad8..444582b0d79cf 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -198,6 +198,7 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.FetchPhase; import org.opensearch.snapshots.mockstore.MockEventuallyConsistentRepository; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.disruption.DisruptableMockTransport; import org.opensearch.threadpool.ThreadPool; @@ -1738,6 +1739,8 @@ public void onFailure(final Exception e) { final IndexNameExpressionResolver indexNameExpressionResolver = new IndexNameExpressionResolver( new ThreadContext(Settings.EMPTY) ); + transportService.getTaskManager() + .setTaskResourceTrackingService(new TaskResourceTrackingService(settings, clusterSettings, threadPool)); repositoriesService = new RepositoriesService( settings, clusterService, diff --git a/server/src/test/java/org/opensearch/tasks/CancelTasksResponseTests.java b/server/src/test/java/org/opensearch/tasks/CancelTasksResponseTests.java index 64d2979c2c5a0..c0ec4ca3d31fd 100644 --- a/server/src/test/java/org/opensearch/tasks/CancelTasksResponseTests.java +++ b/server/src/test/java/org/opensearch/tasks/CancelTasksResponseTests.java @@ -62,7 +62,7 @@ protected CancelTasksResponse createTestInstance() { private static List randomTasks() { List randomTasks = new ArrayList<>(); for (int i = 0; i < randomInt(10); i++) { - randomTasks.add(TaskInfoTests.randomTaskInfo()); + randomTasks.add(TaskInfoTests.randomTaskInfo(false)); } return randomTasks; } diff --git a/server/src/test/java/org/opensearch/tasks/ListTasksResponseTests.java b/server/src/test/java/org/opensearch/tasks/ListTasksResponseTests.java index 450dd522ca891..a89d4d2ba7403 100644 --- a/server/src/test/java/org/opensearch/tasks/ListTasksResponseTests.java +++ b/server/src/test/java/org/opensearch/tasks/ListTasksResponseTests.java @@ -45,6 +45,7 @@ import java.net.ConnectException; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.function.Predicate; import java.util.function.Supplier; @@ -71,7 +72,12 @@ public void testNonEmptyToString() { 1, true, new TaskId("node1", 0), - Collections.singletonMap("foo", "bar") + Collections.singletonMap("foo", "bar"), + new TaskResourceStats(new HashMap() { + { + put("dummy-type1", new TaskResourceUsage(100, 100)); + } + }) ); ListTasksResponse tasksResponse = new ListTasksResponse(singletonList(info), emptyList(), emptyList()); assertEquals( @@ -91,6 +97,12 @@ public void testNonEmptyToString() { + " \"parent_task_id\" : \"node1:0\",\n" + " \"headers\" : {\n" + " \"foo\" : \"bar\"\n" + + " },\n" + + " \"resource_stats\" : {\n" + + " \"dummy-type1\" : {\n" + + " \"cpu_time_in_nanos\" : 100,\n" + + " \"memory_in_bytes\" : 100\n" + + " }\n" + " }\n" + " }\n" + " ]\n" @@ -125,8 +137,8 @@ protected boolean supportsUnknownFields() { @Override protected Predicate getRandomFieldsExcludeFilter() { - // status and headers hold arbitrary content, we can't inject random fields in them - return field -> field.endsWith("status") || field.endsWith("headers"); + // status, headers and resource_stats hold arbitrary content, we can't inject random fields in them + return field -> field.endsWith("status") || field.endsWith("headers") || field.contains("resource_stats"); } @Override diff --git a/server/src/test/java/org/opensearch/tasks/TaskInfoTests.java b/server/src/test/java/org/opensearch/tasks/TaskInfoTests.java index b9a0d05149bb8..e28588f2255d3 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskInfoTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskInfoTests.java @@ -77,13 +77,13 @@ protected boolean supportsUnknownFields() { @Override protected Predicate getRandomFieldsExcludeFilter() { - // status and headers hold arbitrary content, we can't inject random fields in them - return field -> "status".equals(field) || "headers".equals(field); + // status, headers and resource_stats hold arbitrary content, we can't inject random fields in them + return field -> "status".equals(field) || "headers".equals(field) || field.contains("resource_stats"); } @Override protected TaskInfo mutateInstance(TaskInfo info) { - switch (between(0, 9)) { + switch (between(0, 10)) { case 0: TaskId taskId = new TaskId(info.getTaskId().getNodeId() + randomAlphaOfLength(5), info.getTaskId().getId()); return new TaskInfo( @@ -96,7 +96,8 @@ protected TaskInfo mutateInstance(TaskInfo info) { info.getRunningTimeNanos(), info.isCancellable(), info.getParentTaskId(), - info.getHeaders() + info.getHeaders(), + info.getResourceStats() ); case 1: return new TaskInfo( @@ -109,7 +110,8 @@ protected TaskInfo mutateInstance(TaskInfo info) { info.getRunningTimeNanos(), info.isCancellable(), info.getParentTaskId(), - info.getHeaders() + info.getHeaders(), + info.getResourceStats() ); case 2: return new TaskInfo( @@ -122,7 +124,8 @@ protected TaskInfo mutateInstance(TaskInfo info) { info.getRunningTimeNanos(), info.isCancellable(), info.getParentTaskId(), - info.getHeaders() + info.getHeaders(), + info.getResourceStats() ); case 3: return new TaskInfo( @@ -135,7 +138,8 @@ protected TaskInfo mutateInstance(TaskInfo info) { info.getRunningTimeNanos(), info.isCancellable(), info.getParentTaskId(), - info.getHeaders() + info.getHeaders(), + info.getResourceStats() ); case 4: Task.Status newStatus = randomValueOtherThan(info.getStatus(), TaskInfoTests::randomRawTaskStatus); @@ -149,7 +153,8 @@ protected TaskInfo mutateInstance(TaskInfo info) { info.getRunningTimeNanos(), info.isCancellable(), info.getParentTaskId(), - info.getHeaders() + info.getHeaders(), + info.getResourceStats() ); case 5: return new TaskInfo( @@ -162,7 +167,8 @@ protected TaskInfo mutateInstance(TaskInfo info) { info.getRunningTimeNanos(), info.isCancellable(), info.getParentTaskId(), - info.getHeaders() + info.getHeaders(), + info.getResourceStats() ); case 6: return new TaskInfo( @@ -175,7 +181,8 @@ protected TaskInfo mutateInstance(TaskInfo info) { info.getRunningTimeNanos() + between(1, 100), info.isCancellable(), info.getParentTaskId(), - info.getHeaders() + info.getHeaders(), + info.getResourceStats() ); case 7: return new TaskInfo( @@ -188,7 +195,8 @@ protected TaskInfo mutateInstance(TaskInfo info) { info.getRunningTimeNanos(), info.isCancellable() == false, info.getParentTaskId(), - info.getHeaders() + info.getHeaders(), + info.getResourceStats() ); case 8: TaskId parentId = new TaskId(info.getParentTaskId().getNodeId() + randomAlphaOfLength(5), info.getParentTaskId().getId()); @@ -202,7 +210,8 @@ protected TaskInfo mutateInstance(TaskInfo info) { info.getRunningTimeNanos(), info.isCancellable(), parentId, - info.getHeaders() + info.getHeaders(), + info.getResourceStats() ); case 9: Map headers = info.getHeaders(); @@ -222,7 +231,29 @@ protected TaskInfo mutateInstance(TaskInfo info) { info.getRunningTimeNanos(), info.isCancellable(), info.getParentTaskId(), - headers + headers, + info.getResourceStats() + ); + case 10: + Map resourceUsageMap; + if (info.getResourceStats() == null) { + resourceUsageMap = new HashMap<>(1); + } else { + resourceUsageMap = new HashMap<>(info.getResourceStats().getResourceUsageInfo()); + } + resourceUsageMap.put(randomAlphaOfLength(5), new TaskResourceUsage(randomNonNegativeLong(), randomNonNegativeLong())); + return new TaskInfo( + info.getTaskId(), + info.getType(), + info.getAction(), + info.getDescription(), + info.getStatus(), + info.getStartTime(), + info.getRunningTimeNanos(), + info.isCancellable(), + info.getParentTaskId(), + info.getHeaders(), + new TaskResourceStats(resourceUsageMap) ); default: throw new IllegalStateException(); @@ -230,11 +261,15 @@ protected TaskInfo mutateInstance(TaskInfo info) { } static TaskInfo randomTaskInfo() { + return randomTaskInfo(randomBoolean()); + } + + static TaskInfo randomTaskInfo(boolean detailed) { TaskId taskId = randomTaskId(); String type = randomAlphaOfLength(5); String action = randomAlphaOfLength(5); - Task.Status status = randomBoolean() ? randomRawTaskStatus() : null; - String description = randomBoolean() ? randomAlphaOfLength(5) : null; + Task.Status status = detailed ? randomRawTaskStatus() : null; + String description = detailed ? randomAlphaOfLength(5) : null; long startTime = randomLong(); long runningTimeNanos = randomLong(); boolean cancellable = randomBoolean(); @@ -242,7 +277,19 @@ static TaskInfo randomTaskInfo() { Map headers = randomBoolean() ? Collections.emptyMap() : Collections.singletonMap(randomAlphaOfLength(5), randomAlphaOfLength(5)); - return new TaskInfo(taskId, type, action, description, status, startTime, runningTimeNanos, cancellable, parentTaskId, headers); + return new TaskInfo( + taskId, + type, + action, + description, + status, + startTime, + runningTimeNanos, + cancellable, + parentTaskId, + headers, + randomResourceStats(detailed) + ); } private static TaskId randomTaskId() { @@ -262,4 +309,14 @@ private static RawTaskStatus randomRawTaskStatus() { throw new IllegalStateException(e); } } + + public static TaskResourceStats randomResourceStats(boolean detailed) { + return detailed ? new TaskResourceStats(new HashMap() { + { + for (int i = 0; i < randomInt(5); i++) { + put(randomAlphaOfLength(5), new TaskResourceUsage(randomNonNegativeLong(), randomNonNegativeLong())); + } + } + }) : null; + } } diff --git a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java index 0f09b0de34206..ab49109eb8247 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java @@ -40,6 +40,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.FakeTcpChannel; @@ -59,6 +60,7 @@ import java.util.Set; import java.util.concurrent.Phaser; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.everyItem; @@ -67,10 +69,12 @@ public class TaskManagerTests extends OpenSearchTestCase { private ThreadPool threadPool; + private AtomicReference runnableTaskListener; @Before public void setupThreadPool() { - threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName()); + runnableTaskListener = new AtomicReference<>(); + threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), runnableTaskListener); } @After diff --git a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java new file mode 100644 index 0000000000000..8ba23c5d3219c --- /dev/null +++ b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java @@ -0,0 +1,97 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.action.admin.cluster.node.tasks.TransportTasksActionTests; +import org.opensearch.action.search.SearchTask; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import java.util.HashMap; +import java.util.concurrent.atomic.AtomicReference; + +import static org.opensearch.tasks.ResourceStats.MEMORY; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; + +public class TaskResourceTrackingServiceTests extends OpenSearchTestCase { + + private ThreadPool threadPool; + private TaskResourceTrackingService taskResourceTrackingService; + + @Before + public void setup() { + threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), new AtomicReference<>()); + taskResourceTrackingService = new TaskResourceTrackingService( + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), + threadPool + ); + } + + @After + public void terminateThreadPool() { + terminate(threadPool); + } + + public void testThreadContextUpdateOnTrackingStart() { + taskResourceTrackingService.setTaskResourceTrackingEnabled(true); + + Task task = new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>()); + + String key = "KEY"; + String value = "VALUE"; + + // Prepare thread context + threadPool.getThreadContext().putHeader(key, value); + threadPool.getThreadContext().putTransient(key, value); + threadPool.getThreadContext().addResponseHeader(key, value); + + ThreadContext.StoredContext storedContext = taskResourceTrackingService.startTracking(task); + + // All headers should be preserved and Task Id should also be included in thread context + verifyThreadContextFixedHeaders(key, value); + assertEquals((long) threadPool.getThreadContext().getTransient(TASK_ID), task.getId()); + + storedContext.restore(); + + // Post restore only task id should be removed from the thread context + verifyThreadContextFixedHeaders(key, value); + assertNull(threadPool.getThreadContext().getTransient(TASK_ID)); + } + + public void testStopTrackingHandlesCurrentActiveThread() { + taskResourceTrackingService.setTaskResourceTrackingEnabled(true); + Task task = new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>()); + ThreadContext.StoredContext storedContext = taskResourceTrackingService.startTracking(task); + long threadId = Thread.currentThread().getId(); + taskResourceTrackingService.taskExecutionStartedOnThread(task.getId(), threadId); + + assertTrue(task.getResourceStats().get(threadId).get(0).isActive()); + assertEquals(0, task.getResourceStats().get(threadId).get(0).getResourceUsageInfo().getStatsInfo().get(MEMORY).getTotalValue()); + + taskResourceTrackingService.stopTracking(task); + + // Makes sure stop tracking marks the current active thread inactive and refreshes the resource stats before returning. + assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); + assertTrue(task.getResourceStats().get(threadId).get(0).getResourceUsageInfo().getStatsInfo().get(MEMORY).getTotalValue() > 0); + } + + private void verifyThreadContextFixedHeaders(String key, String value) { + assertEquals(threadPool.getThreadContext().getHeader(key), value); + assertEquals(threadPool.getThreadContext().getTransient(key), value); + assertEquals(threadPool.getThreadContext().getResponseHeaders().get(key).get(0), value); + } + +} diff --git a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java index e60871f67ea54..677ec7a0a6600 100644 --- a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java +++ b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java @@ -39,6 +39,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Setting.Property; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskAwareRequest; import org.opensearch.tasks.TaskManager; @@ -127,6 +128,21 @@ public void waitForTaskCompletion(Task task, long untilInNanos) { super.waitForTaskCompletion(task, untilInNanos); } + @Override + public ThreadContext.StoredContext taskExecutionStarted(Task task) { + for (MockTaskManagerListener listener : listeners) { + listener.taskExecutionStarted(task, false); + } + + ThreadContext.StoredContext storedContext = super.taskExecutionStarted(task); + return () -> { + for (MockTaskManagerListener listener : listeners) { + listener.taskExecutionStarted(task, true); + } + storedContext.restore(); + }; + } + public void addListener(MockTaskManagerListener listener) { listeners.add(listener); } diff --git a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java index eb8361ac552fc..f15f878995aa2 100644 --- a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java +++ b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java @@ -43,4 +43,7 @@ public interface MockTaskManagerListener { void onTaskUnregistered(Task task); void waitForTaskCompletion(Task task); + + void taskExecutionStarted(Task task, Boolean closeableInvoked); + } diff --git a/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java b/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java index 5f8611d99f0a0..2d97d5bffee01 100644 --- a/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java +++ b/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java @@ -40,6 +40,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicReference; public class TestThreadPool extends ThreadPool { @@ -47,12 +48,29 @@ public class TestThreadPool extends ThreadPool { private volatile boolean returnRejectingExecutor = false; private volatile ThreadPoolExecutor rejectingExecutor; + public TestThreadPool( + String name, + AtomicReference runnableTaskListener, + ExecutorBuilder... customBuilders + ) { + this(name, Settings.EMPTY, runnableTaskListener, customBuilders); + } + public TestThreadPool(String name, ExecutorBuilder... customBuilders) { this(name, Settings.EMPTY, customBuilders); } public TestThreadPool(String name, Settings settings, ExecutorBuilder... customBuilders) { - super(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), name).put(settings).build(), customBuilders); + this(name, settings, null, customBuilders); + } + + public TestThreadPool( + String name, + Settings settings, + AtomicReference runnableTaskListener, + ExecutorBuilder... customBuilders + ) { + super(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), name).put(settings).build(), runnableTaskListener, customBuilders); } @Override