Skip to content

Commit

Permalink
Add custom span to AggregationOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-spies committed Jun 10, 2024
1 parent 730e693 commit ad3afc7
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ public StoredContext stashContextPreservingRequestHeaders(final String... reques
return stashContextPreservingRequestHeaders(Set.of(requestHeaders));
}

/**
 * Returns the {@link ThreadContextStruct} currently held by {@code threadLocal}.
 *
 * @return whatever {@code threadLocal.get()} yields at the time of the call
 */
public ThreadContextStruct getThreadContextStruct() {
    final ThreadContextStruct current = threadLocal.get();
    return current;
}

/**
* When using a {@link org.elasticsearch.telemetry.tracing.Tracer} to capture activity in Elasticsearch, when a parent span is already
* in progress, it is necessary to start a new context before beginning a child span. This method creates a context,
Expand Down Expand Up @@ -705,7 +709,7 @@ public static Map<String, String> buildDefaultHeaders(Settings settings) {
}
}

private static final class ThreadContextStruct {
public static final class ThreadContextStruct {

private static final ThreadContextStruct EMPTY = new ThreadContextStruct(
Collections.emptyMap(),
Expand All @@ -714,12 +718,12 @@ private static final class ThreadContextStruct {
false
);

private final Map<String, String> requestHeaders;
private final Map<String, Object> transientHeaders;
private final Map<String, Set<String>> responseHeaders;
private final boolean isSystemContext;
public final Map<String, String> requestHeaders;
public final Map<String, Object> transientHeaders;
public final Map<String, Set<String>> responseHeaders;
public final boolean isSystemContext;
// saving current warning headers' size not to recalculate the size with every new warning header
private final long warningHeadersSize;
public final long warningHeadersSize;

private ThreadContextStruct setSystemContext() {
if (isSystemContext) {
Expand All @@ -737,7 +741,7 @@ private ThreadContextStruct(
this(requestHeaders, responseHeaders, transientHeaders, isSystemContext, 0L);
}

private ThreadContextStruct(
public ThreadContextStruct(
Map<String, String> requestHeaders,
Map<String, Set<String>> responseHeaders,
Map<String, Object> transientHeaders,
Expand All @@ -758,7 +762,7 @@ private ThreadContextStruct() {
this(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), false);
}

private ThreadContextStruct putRequest(String key, String value) {
public ThreadContextStruct putRequest(String key, String value) {
Map<String, String> newRequestHeaders = new HashMap<>(this.requestHeaders);
putSingleHeader(key, value, newRequestHeaders);
return new ThreadContextStruct(newRequestHeaders, responseHeaders, transientHeaders, isSystemContext);
Expand All @@ -770,7 +774,7 @@ private static <T> void putSingleHeader(String key, T value, Map<String, T> newH
}
}

private ThreadContextStruct putHeaders(Map<String, String> headers) {
public ThreadContextStruct putHeaders(Map<String, String> headers) {
if (headers.isEmpty()) {
return this;
} else {
Expand Down Expand Up @@ -870,7 +874,7 @@ private ThreadContextStruct putResponse(
return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext, newWarningHeaderSize);
}

private ThreadContextStruct putTransient(String key, Object value) {
public ThreadContextStruct putTransient(String key, Object value) {
Map<String, Object> newTransient = new HashMap<>(this.transientHeaders);
putSingleHeader(key, value, newTransient);
return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, isSystemContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;

Expand Down Expand Up @@ -53,6 +56,89 @@ public static Releasable span(ThreadPool threadPool, Tracer tracer, String name,
};
}

/**
 * A {@link TraceContext} backed by a private copy of the headers found in a {@link ThreadContext}.
 * <p>
 * The constructor reads the parent's current {@link ThreadContext.ThreadContextStruct}, copies its request and
 * transient header maps, and moves the existing trace headers ({@link Task#TRACE_PARENT_HTTP_HEADER},
 * {@link Task#TRACE_STATE} and {@link Task#APM_TRACE_CONTEXT}) to {@code "parent_"}-prefixed transient headers.
 * This class only reads from the parent {@link ThreadContext}; it never writes back to it.
 */
private static class SubThreadContext implements TraceContext {
    /** Current state of this context; replaced (not mutated) on every put. */
    private ThreadContext.ThreadContextStruct ctx;

    /**
     * @param parent the thread context whose current headers seed this context
     */
    SubThreadContext(ThreadContext parent) {
        final ThreadContext.ThreadContextStruct parentStruct = parent.getThreadContextStruct();
        // Work on copies so the parent's maps are left untouched.
        final Map<String, String> newRequestHeaders = new HashMap<>(parentStruct.requestHeaders);
        final Map<String, Object> newTransientHeaders = new HashMap<>(parentStruct.transientHeaders);

        final String previousTraceParent = newRequestHeaders.remove(Task.TRACE_PARENT_HTTP_HEADER);
        if (previousTraceParent != null) {
            newTransientHeaders.put("parent_" + Task.TRACE_PARENT_HTTP_HEADER, previousTraceParent);
        }

        final String previousTraceState = newRequestHeaders.remove(Task.TRACE_STATE);
        if (previousTraceState != null) {
            newTransientHeaders.put("parent_" + Task.TRACE_STATE, previousTraceState);
        }

        final Object previousTraceContext = newTransientHeaders.remove(Task.APM_TRACE_CONTEXT);
        if (previousTraceContext != null) {
            newTransientHeaders.put("parent_" + Task.APM_TRACE_CONTEXT, previousTraceContext);
        }

        // Response headers, system-context flag and warning header size are carried over unchanged.
        this.ctx = new ThreadContext.ThreadContextStruct(
            newRequestHeaders,
            parentStruct.responseHeaders,
            newTransientHeaders,
            parentStruct.isSystemContext,
            parentStruct.warningHeadersSize
        );
    }

    /**
     * Puts a single request header into this context.
     * <p>
     * NOTE(review): {@code putRequest} delegates to {@code putSingleHeader}, whose body is not visible here;
     * confirm how an already-present key is handled.
     *
     * @param key   the header name
     * @param value the header value
     */
    @Override
    public void putHeader(String key, String value) {
        // putRequest returns a new struct instead of modifying the receiver, so the result must be kept;
        // discarding it would make this call a silent no-op.
        ctx = ctx.putRequest(key, value);
    }

    /**
     * Returns the request header stored under {@code key}, or <code>null</code> if there is none.
     */
    @Override
    public String getHeader(String key) {
        return ctx.requestHeaders.get(key);
    }

    /**
     * Puts a transient header object into this context
     */
    @Override
    public void putTransient(String key, Object value) {
        // Same as putHeader: putTransient returns a new struct, which has to replace the current one.
        ctx = ctx.putTransient(key, value);
    }

    /**
     * Returns a transient header object or <code>null</code> if there is no header for the given key
     */
    @SuppressWarnings("unchecked") // (T)object
    public <T> T getTransient(String key) {
        return (T) ctx.transientHeaders.get(key);
    }
}

/**
 * Convenience overload of {@link #sameThreadContextSpan(ThreadPool, Tracer, String, Map)} that passes an
 * empty attribute map.
 *
 * @param threadPool thread pool forwarded to the four-argument overload
 * @param tracer     tracer forwarded to the four-argument overload
 * @param name       span name forwarded to the four-argument overload
 * @return the {@link Releasable} produced by the four-argument overload
 */
public static Releasable sameThreadContextSpan(ThreadPool threadPool, Tracer tracer, String name) {
    final Map<String, Object> noAttributes = Map.of();
    return sameThreadContextSpan(threadPool, tracer, name, noAttributes);
}

/**
 * Starts a span through {@code tracer} using a {@link SubThreadContext} built from the thread pool's
 * {@link ThreadContext}, and returns a {@link Releasable} that stops the span when closed.
 * <p>
 * If the tracer is disabled, or if {@code tracer} or {@code threadPool} is {@code null}, nothing is started and
 * a no-op {@link Releasable} is returned.
 *
 * @param threadPool source of the {@link ThreadContext} the span context is derived from; may be {@code null}
 * @param tracer     tracer used to start and stop the span; may be {@code null}
 * @param name       name passed to {@code tracer.startTrace}
 * @param attributes attributes passed to {@code tracer.startTrace}
 * @return a {@link Releasable} that calls {@code tracer.stopTrace} for the started span, or a no-op
 */
public static Releasable sameThreadContextSpan(ThreadPool threadPool, Tracer tracer, String name, Map<String, Object> attributes) {
    // Callers such as AggregationOperator receive both collaborators through optional setters, so either
    // may still be unset here; treat that like a disabled tracer instead of failing with an NPE.
    if (tracer == null || threadPool == null || tracer.isEnabled() == false) {
        return () -> {};
    }
    var span = Span.create();
    final SubThreadContext traceContext = new SubThreadContext(threadPool.getThreadContext());
    tracer.startTrace(traceContext, span, name, attributes);

    return () -> tracer.stopTrace(span);
}

public static <T> void span(
ThreadPool threadPool,
Tracer tracer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.telemetry.tracing.Tracer;
import org.elasticsearch.telemetry.tracing.TracerSpan;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
Expand All @@ -39,11 +43,23 @@
* been added, that is, when the {@link #finish} method has been called.
*/
public class AggregationOperator implements Operator {

private ThreadPool threadPool;
private Tracer tracer;
private boolean finished;
private Page output;
private final List<Aggregator> aggregators;
private final DriverContext driverContext;
private Releasable span;

/**
 * Stores the thread pool that {@link #addInput} later hands to
 * {@code TracerSpan.sameThreadContextSpan} when it opens the operator's span.
 *
 * @param threadPool the thread pool to remember; replaces any previously stored value
 */
@Override
public void setThreadPool(ThreadPool threadPool) {
this.threadPool = threadPool;
}

/**
 * Stores the tracer that {@link #addInput} later hands to
 * {@code TracerSpan.sameThreadContextSpan} when it opens the operator's span.
 *
 * @param tracer the tracer to remember; replaces any previously stored value
 */
@Override
public void setTracer(Tracer tracer) {
this.tracer = tracer;
}

/**
* Nanoseconds this operator has spent running the aggregations.
Expand Down Expand Up @@ -90,6 +106,9 @@ public boolean needsInput() {

@Override
public void addInput(Page page) {
if (span == null) {
span = TracerSpan.sameThreadContextSpan(threadPool, tracer, "AggregationOperator");
}
long start = System.nanoTime();
checkState(needsInput(), "Operator is already finishing");
requireNonNull(page, "page is null");
Expand Down Expand Up @@ -135,6 +154,8 @@ public void finish() {
if (success == false && blocks != null) {
Releasables.closeExpectNoException(blocks);
}
span.close();
span = null;
}
}

Expand All @@ -149,7 +170,7 @@ public void close() {
if (output != null) {
Releasables.closeExpectNoException(() -> output.releaseBlocks());
}
}, Releasables.wrap(aggregators));
}, Releasables.wrap(aggregators), span);
}

private static void checkState(boolean condition, String msg) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.telemetry.tracing.Tracer;
import org.elasticsearch.telemetry.tracing.TracerSpan;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.ArrayList;
import java.util.Iterator;
Expand Down Expand Up @@ -50,6 +53,16 @@ public class Driver implements Releasable, Describable {
public static final TimeValue DEFAULT_STATUS_INTERVAL = TimeValue.timeValueSeconds(1);

private final String sessionId;
private ThreadPool threadPool;
private Tracer tracer;

/**
 * Stores the tracer this driver uses for its {@code "Driver.run"} span and passes on to each
 * active operator via {@code Operator#setTracer}.
 *
 * @param tracer the tracer to remember; replaces any previously stored value
 */
public void setTracer(Tracer tracer) {
this.tracer = tracer;
}

/**
 * Stores the thread pool this driver uses for its {@code "Driver.run"} span and passes on to each
 * active operator via {@code Operator#setThreadPool}.
 *
 * @param threadPool the thread pool to remember; replaces any previously stored value
 */
public void setThreadPool(ThreadPool threadPool) {
this.threadPool = threadPool;
}

/**
* The wall clock time when this driver was created in milliseconds since epoch.
Expand Down Expand Up @@ -170,36 +183,48 @@ public DriverContext driverContext() {
* thread to do other work instead of blocking or busy-spinning on the blocked operator.
*/
SubscribableListener<Void> run(TimeValue maxTime, int maxIterations, LongSupplier nowSupplier) {
long maxTimeNanos = maxTime.nanos();
long startTime = nowSupplier.getAsLong();
long nextStatus = startTime + statusNanos;
int iter = 0;
while (true) {
SubscribableListener<Void> fut = runSingleLoopIteration();
iter++;
if (fut.isDone() == false) {
updateStatus(nowSupplier.getAsLong() - startTime, iter, DriverStatus.Status.ASYNC);
return fut;
}
if (isFinished()) {
finishNanos = nowSupplier.getAsLong();
updateStatus(finishNanos - startTime, iter, DriverStatus.Status.DONE);
driverContext.finish();
Releasables.close(releasable, driverContext.getSnapshot());
return Operator.NOT_BLOCKED;
}
long now = nowSupplier.getAsLong();
if (iter >= maxIterations) {
updateStatus(now - startTime, iter, DriverStatus.Status.WAITING);
return Operator.NOT_BLOCKED;
try (
var span = TracerSpan.span(
threadPool,
tracer,
"Driver.run"
)
) {
long maxTimeNanos = maxTime.nanos();
long startTime = nowSupplier.getAsLong();
long nextStatus = startTime + statusNanos;
int iter = 0;
for (Operator op: activeOperators) {
op.setThreadPool(threadPool);
op.setTracer(tracer);
}
if (now - startTime >= maxTimeNanos) {
updateStatus(now - startTime, iter, DriverStatus.Status.WAITING);
return Operator.NOT_BLOCKED;
}
if (now > nextStatus) {
updateStatus(now - startTime, iter, DriverStatus.Status.RUNNING);
nextStatus = now + statusNanos;
while (true) {
SubscribableListener<Void> fut = runSingleLoopIteration();
iter++;
if (fut.isDone() == false) {
updateStatus(nowSupplier.getAsLong() - startTime, iter, DriverStatus.Status.ASYNC);
return fut;
}
if (isFinished()) {
finishNanos = nowSupplier.getAsLong();
updateStatus(finishNanos - startTime, iter, DriverStatus.Status.DONE);
driverContext.finish();
Releasables.close(releasable, driverContext.getSnapshot());
return Operator.NOT_BLOCKED;
}
long now = nowSupplier.getAsLong();
if (iter >= maxIterations) {
updateStatus(now - startTime, iter, DriverStatus.Status.WAITING);
return Operator.NOT_BLOCKED;
}
if (now - startTime >= maxTimeNanos) {
updateStatus(now - startTime, iter, DriverStatus.Status.WAITING);
return Operator.NOT_BLOCKED;
}
if (now > nextStatus) {
updateStatus(now - startTime, iter, DriverStatus.Status.RUNNING);
nextStatus = now + statusNanos;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.telemetry.tracing.Tracer;
import org.elasticsearch.telemetry.tracing.TracerSpan;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
Expand All @@ -42,10 +45,12 @@ public DriverTaskRunner(TransportService transportService, Executor executor) {
transportService.registerRequestHandler(ACTION_NAME, executor, DriverRequest::new, new DriverRequestHandler(transportService));
}

public void executeDrivers(Task parentTask, List<Driver> drivers, Executor executor, ActionListener<Void> listener) {
public void executeDrivers(Task parentTask, List<Driver> drivers, Executor executor, ActionListener<Void> listener, Tracer tracer) {
var runner = new DriverRunner(transportService.getThreadPool().getThreadContext()) {
@Override
protected void start(Driver driver, ActionListener<Void> driverListener) {
driver.setTracer(tracer);
driver.setThreadPool(transportService.getThreadPool());
transportService.sendChildRequest(
transportService.getLocalNode(),
ACTION_NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@
package org.elasticsearch.compute.operator;

import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.compute.Describable;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.telemetry.tracing.Traceable;
import org.elasticsearch.telemetry.tracing.Tracer;
import org.elasticsearch.telemetry.tracing.TracerSpan;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContentObject;

/**
Expand All @@ -27,6 +32,8 @@
* {@link org.elasticsearch.compute}
*/
public interface Operator extends Releasable {
/**
 * Offers a {@link Tracer} to this operator. The default implementation ignores it;
 * operators that want to record spans override this method.
 */
default void setTracer(Tracer tracer) {}
/**
 * Offers a {@link ThreadPool} to this operator. The default implementation ignores it;
 * operators that want to record spans override this method.
 */
default void setThreadPool(ThreadPool threadPool) {}
/**
* Target number of bytes in a page. By default we'll try and size pages
* so that they contain this many bytes.
Expand Down
Loading

0 comments on commit ad3afc7

Please sign in to comment.