Skip to content

Commit

Permalink
Respect missing B3 sampled headers (#1512)
Browse files Browse the repository at this point in the history
* Respect missing B3 sampled headers
  • Loading branch information
tkountis committed Apr 28, 2021
1 parent 4bf74e8 commit 619c76a
Show file tree
Hide file tree
Showing 16 changed files with 128 additions and 114 deletions.
Expand Up @@ -17,6 +17,7 @@

import io.servicetalk.http.api.HttpHeaders;
import io.servicetalk.opentracing.inmemory.DefaultInMemoryTraceState;
import io.servicetalk.opentracing.inmemory.api.InMemorySpanContext;
import io.servicetalk.opentracing.inmemory.api.InMemoryTraceState;
import io.servicetalk.opentracing.inmemory.api.InMemoryTraceStateFormat;
import io.servicetalk.opentracing.internal.ZipkinHeaderNames;
Expand Down Expand Up @@ -58,14 +59,15 @@ static InMemoryTraceStateFormat<HttpHeaders> traceStateFormatter(boolean validat
}

@Override
public void inject(final InMemoryTraceState state, final HttpHeaders carrier) {
public void inject(final InMemorySpanContext context, final HttpHeaders carrier) {
final InMemoryTraceState state = context.traceState();
carrier.set(TRACE_ID, state.traceIdHex());
carrier.set(SPAN_ID, state.spanIdHex());
String parentSpanIdHex = state.parentSpanIdHex();
if (parentSpanIdHex != null) {
carrier.set(PARENT_SPAN_ID, parentSpanIdHex);
}
carrier.set(SAMPLED, state.isSampled() ? "1" : "0");
carrier.set(SAMPLED, context.isSampled() ? "1" : "0");
}

@Nullable
Expand Down Expand Up @@ -96,6 +98,6 @@ public InMemoryTraceState extract(final HttpHeaders carrier) {

CharSequence sampleId = carrier.get(SAMPLED);
return new DefaultInMemoryTraceState(traceId.toString(), spanId.toString(),
valueOf(parentSpanId), sampleId != null && sampleId.length() == 1 && sampleId.charAt(0) != '0');
valueOf(parentSpanId), sampleId != null ? (sampleId.length() == 1 && sampleId.charAt(0) != '0') : null);
}
}
Expand Up @@ -143,7 +143,7 @@ public void testInjectWithNoParent() throws Exception {
assertFalse(lastFinishedSpan.tags().containsKey(ERROR.getKey()));

verifyTraceIdPresentInLogs(stableAccumulated(1000), requestUrl, serverSpanState.traceId,
serverSpanState.spanId, serverSpanState.parentSpanId, TRACING_TEST_LOG_LINE_PREFIX);
serverSpanState.spanId, null, TRACING_TEST_LOG_LINE_PREFIX);
}
}
}
Expand All @@ -168,8 +168,10 @@ public void testInjectWithParent() throws Exception {
assertThat(serverSpanState.spanId, isHexId());
assertThat(serverSpanState.parentSpanId, isHexId());

assertThat(serverSpanState.traceId, equalToIgnoringCase(clientSpan.traceIdHex()));
assertThat(serverSpanState.parentSpanId, equalToIgnoringCase(clientSpan.spanIdHex()));
assertThat(serverSpanState.traceId, equalToIgnoringCase(
clientSpan.context().traceState().traceIdHex()));
assertThat(serverSpanState.parentSpanId, equalToIgnoringCase(
clientSpan.context().traceState().spanIdHex()));

// don't mess with caller span state
assertEquals(clientSpan, tracer.activeSpan());
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright © 2018-2019 Apple Inc. and the ServiceTalk project authors
* Copyright © 2018-2019, 2021 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,6 +34,7 @@
import io.servicetalk.log4j2.mdc.utils.LoggerStringWriter;
import io.servicetalk.opentracing.http.TestUtils.CountingInMemorySpanEventListener;
import io.servicetalk.opentracing.inmemory.DefaultInMemoryTracer;
import io.servicetalk.opentracing.inmemory.SamplingStrategies;
import io.servicetalk.opentracing.inmemory.api.InMemorySpan;
import io.servicetalk.transport.api.ServerContext;

Expand All @@ -47,27 +48,21 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.TimeoutException;
import java.util.function.BiFunction;
import javax.annotation.Nullable;

import static io.opentracing.tag.Tags.ERROR;
import static io.opentracing.tag.Tags.HTTP_METHOD;
import static io.opentracing.tag.Tags.HTTP_STATUS;
import static io.opentracing.tag.Tags.HTTP_URL;
import static io.opentracing.tag.Tags.SPAN_KIND;
import static io.opentracing.tag.Tags.SPAN_KIND_SERVER;
import static io.servicetalk.concurrent.api.Publisher.from;
import static io.servicetalk.concurrent.api.Single.succeeded;
import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION;
import static io.servicetalk.http.api.HttpRequestMethod.GET;
import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.http.api.HttpSerializationProviders.jsonSerializer;
import static io.servicetalk.http.api.HttpSerializationProviders.textSerializer;
import static io.servicetalk.http.netty.HttpClients.forSingleAddress;
import static io.servicetalk.log4j2.mdc.utils.LoggerStringWriter.stableAccumulated;
import static io.servicetalk.opentracing.asynccontext.AsyncContextInMemoryScopeManager.SCOPE_MANAGER;
import static io.servicetalk.opentracing.http.TestUtils.TRACING_TEST_LOG_LINE_PREFIX;
import static io.servicetalk.opentracing.http.TestUtils.isHexId;
import static io.servicetalk.opentracing.http.TestUtils.randomHexId;
import static io.servicetalk.opentracing.http.TestUtils.verifyTraceIdPresentInLogs;
import static io.servicetalk.opentracing.internal.ZipkinHeaderNames.PARENT_SPAN_ID;
Expand All @@ -84,7 +79,6 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.initMocks;
Expand All @@ -111,7 +105,13 @@ public void tearDown() {
}

private static ServerContext buildServer(CountingInMemorySpanEventListener spanListener) throws Exception {
return buildServer(spanListener, SamplingStrategies.sampleUnlessFalse());
}

private static ServerContext buildServer(CountingInMemorySpanEventListener spanListener,
BiFunction<String, Boolean, Boolean> sampler) throws Exception {
DefaultInMemoryTracer tracer = new DefaultInMemoryTracer.Builder(SCOPE_MANAGER)
.withSampler(sampler)
.addListener(spanListener).build();
return HttpServers.forAddress(localAddress(0))
.appendServiceFilter(new TracingHttpServiceFilter(tracer, "testServer"))
Expand All @@ -123,10 +123,10 @@ private static ServerContext buildServer(CountingInMemorySpanEventListener spanL
textSerializer()));
}
return succeeded(responseFactory.ok().payloadBody(from(new TestSpanState(
span.traceIdHex(),
span.spanIdHex(),
span.parentSpanIdHex(),
span.isSampled(),
span.context().traceState().traceIdHex(),
span.context().traceState().spanIdHex(),
span.context().traceState().parentSpanIdHex(),
span.context().isSampled(),
span.tags().containsKey(ERROR.getKey()))),
httpSerializer.serializerFor(TestSpanState.class)));
});
Expand All @@ -142,60 +142,80 @@ public void testRequestWithTraceKey() throws Exception {
String parentSpanId = randomHexId();
String requestUrl = "/";
HttpRequest request = client.get(requestUrl);
request.headers().set(TRACE_ID, traceId)
request.headers()
.set(TRACE_ID, traceId)
.set(SPAN_ID, spanId)
.set(PARENT_SPAN_ID, parentSpanId)
.set(SAMPLED, "0");
HttpResponse response = client.request(request).toFuture().get();
TestSpanState serverSpanState = response.payloadBody(httpSerializer.deserializerFor(
TestSpanState.class));
assertThat(serverSpanState.traceId, equalToIgnoringCase(traceId));
assertThat(serverSpanState.spanId, not(equalToIgnoringCase(spanId)));
assertThat(serverSpanState.parentSpanId, equalToIgnoringCase(spanId));
assertFalse(serverSpanState.sampled);
assertFalse(serverSpanState.error);
assertEquals(0, spanListener.spanFinishedCount()); // not sampled, so no finish

InMemorySpan lastFinishedSpan = spanListener.lastFinishedSpan();
assertNull(lastFinishedSpan);

verifyTraceIdPresentInLogs(stableAccumulated(1000), requestUrl, serverSpanState.traceId,
serverSpanState.spanId, serverSpanState.parentSpanId, TRACING_TEST_LOG_LINE_PREFIX);
assertSpan(spanListener, traceId, spanId, requestUrl, serverSpanState, false);
}
}
}

@Test
public void testRequestWithoutTraceKey() throws Exception {
final String requestUrl = "/foo";
public void testRequestWithTraceKeyWithoutSampled() throws Exception {
CountingInMemorySpanEventListener spanListener = new CountingInMemorySpanEventListener();
try (ServerContext context = buildServer(spanListener)) {
try (HttpClient client = forSingleAddress(serverHostAndPort(context)).build()) {
String traceId = randomHexId();
String spanId = randomHexId();
String requestUrl = "/";
HttpRequest request = client.get(requestUrl);
request.headers().set(TRACE_ID, traceId)
.set(SPAN_ID, spanId);
HttpResponse response = client.request(request).toFuture().get();
TestSpanState serverSpanState = response.payloadBody(httpSerializer.deserializerFor(
TestSpanState.class));
assertThat(serverSpanState.traceId, isHexId());
assertThat(serverSpanState.spanId, isHexId());
assertNull(serverSpanState.parentSpanId);
assertTrue(serverSpanState.sampled);
assertFalse(serverSpanState.error);
assertEquals(1, spanListener.spanFinishedCount()); // sampled, so only finish once!

InMemorySpan lastFinishedSpan = spanListener.lastFinishedSpan();
assertNotNull(lastFinishedSpan);
assertEquals(SPAN_KIND_SERVER, lastFinishedSpan.tags().get(SPAN_KIND.getKey()));
assertEquals(GET.name(), lastFinishedSpan.tags().get(HTTP_METHOD.getKey()));
assertEquals(requestUrl, lastFinishedSpan.tags().get(HTTP_URL.getKey()));
assertEquals(OK.code(), lastFinishedSpan.tags().get(HTTP_STATUS.getKey()));
assertFalse(lastFinishedSpan.tags().containsKey(ERROR.getKey()));
assertSpan(spanListener, traceId, spanId, requestUrl, serverSpanState, true);
}
}
}

verifyTraceIdPresentInLogs(stableAccumulated(1000), requestUrl, serverSpanState.traceId,
serverSpanState.spanId, serverSpanState.parentSpanId, TRACING_TEST_LOG_LINE_PREFIX);
@Test
public void testRequestWithTraceKeyWithNegativeSampledAndAlwaysTrueSampler() throws Exception {
final CountingInMemorySpanEventListener spanListener = new CountingInMemorySpanEventListener();
try (ServerContext context = buildServer(spanListener, (__, ___) -> true)) {
try (HttpClient client = forSingleAddress(serverHostAndPort(context)).build()) {
String traceId = randomHexId();
String spanId = randomHexId();
String requestUrl = "/";
HttpRequest request = client.get(requestUrl);
request.headers().set(TRACE_ID, traceId)
.set(SPAN_ID, spanId)
.set(SAMPLED, "0");
HttpResponse response = client.request(request).toFuture().get();
TestSpanState serverSpanState = response.payloadBody(httpSerializer.deserializerFor(
TestSpanState.class));
assertSpan(spanListener, traceId, spanId, requestUrl, serverSpanState, true);
}
}
}

private void assertSpan(final CountingInMemorySpanEventListener spanListener, final String traceId,
final String spanId, final String requestUrl, final TestSpanState serverSpanState,
final boolean expectedSampled)
throws InterruptedException, TimeoutException {
assertThat(serverSpanState.traceId, equalToIgnoringCase(traceId));
assertThat(serverSpanState.spanId, not(equalToIgnoringCase(spanId)));
assertThat(serverSpanState.parentSpanId, equalToIgnoringCase(spanId));
assertEquals(expectedSampled, serverSpanState.sampled);
assertFalse(serverSpanState.error);
assertEquals(expectedSampled ? 1 : 0, spanListener.spanFinishedCount());

InMemorySpan lastFinishedSpan = spanListener.lastFinishedSpan();
if (expectedSampled) {
assertNotNull(lastFinishedSpan);
} else {
assertNull(lastFinishedSpan);
}

verifyTraceIdPresentInLogs(stableAccumulated(1000), requestUrl, serverSpanState.traceId,
serverSpanState.spanId, serverSpanState.parentSpanId, TRACING_TEST_LOG_LINE_PREFIX);
}

@Test
public void tracerThrowsReturnsErrorResponse() throws Exception {
when(mockTracer.buildSpan(any())).thenThrow(DELIBERATE_EXCEPTION);
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright © 2018 Apple Inc. and the ServiceTalk project authors
* Copyright © 2018, 2021 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,7 +26,7 @@
/**
* A span that allows reading values at runtime.
*/
public interface InMemorySpan extends Span, InMemoryTraceState {
public interface InMemorySpan extends Span {
@Override
InMemorySpanContext context();

Expand All @@ -50,7 +50,7 @@ public interface InMemorySpan extends Span, InMemoryTraceState {
* @return low 64 bits of the trace ID
*/
default long traceId() {
String traceIdHex = traceIdHex();
String traceIdHex = context().traceState().traceIdHex();
return longOfHexBytes(traceIdHex, traceIdHex.length() >= 32 ? 16 : 0);
}

Expand All @@ -60,7 +60,7 @@ default long traceId() {
* @return high 64 bits of the trace ID
*/
default long traceIdHigh() {
String traceIdHex = traceIdHex();
String traceIdHex = context().traceState().traceIdHex();
return traceIdHex.length() >= 32 ? longOfHexBytes(traceIdHex, 0) : 0;
}

Expand All @@ -70,7 +70,7 @@ default long traceIdHigh() {
* @return span ID
*/
default long spanId() {
return longOfHexBytes(spanIdHex(), 0);
return longOfHexBytes(context().traceState().spanIdHex(), 0);
}

/**
Expand All @@ -80,7 +80,7 @@ default long spanId() {
*/
@Nullable
default Long parentSpanId() {
String parentSpanIdHex = parentSpanIdHex();
String parentSpanIdHex = context().traceState().parentSpanIdHex();
return parentSpanIdHex == null ? null : longOfHexBytes(parentSpanIdHex, 0);
}

Expand All @@ -90,7 +90,7 @@ default Long parentSpanId() {
* @return parent span ID in hex
*/
default String nonnullParentSpanIdHex() {
String parentSpanIdHex = parentSpanIdHex();
String parentSpanIdHex = context().traceState().parentSpanIdHex();
return parentSpanIdHex == null ? NO_PARENT_ID : parentSpanIdHex;
}

Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright © 2018 Apple Inc. and the ServiceTalk project authors
* Copyright © 2018, 2021 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,14 +30,12 @@ public interface InMemorySpanContext extends SpanContext {
/**
* Returns whether the span should be sampled.
* <p>
* Note this may differ from {@link InMemorySpan#isSampled()} from {@link #traceState()} if the value is overridden
* based upon some sampling policy.
* Note this may differ from {@link InMemoryTraceState#isSampled()} if the value is overridden based upon
* some sampling policy.
*
* @return whether the span should be sampled
*/
default boolean isSampled() {
return traceState().isSampled();
}
boolean isSampled();

default String toTraceId() {
return traceState().traceIdHex();
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright © 2018 Apple Inc. and the ServiceTalk project authors
* Copyright © 2018, 2021 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@
import javax.annotation.Nullable;

/**
* Utility for representing a Ziplin-like trace state.
* Utility for representing a Zipkin-like trace state.
*/
public interface InMemoryTraceState {
/**
Expand All @@ -42,7 +42,9 @@ public interface InMemoryTraceState {

/**
* Determine if this state is sampled.
* @return {@code true} if this state is sampled.
* @return {@code true} if this state is sampled, {@code false} if this state isn't sampled and
* {@code null} if sampling is not specified.
*/
boolean isSampled();
@Nullable
Boolean isSampled();
}
Expand Up @@ -27,10 +27,10 @@ public interface InMemoryTraceStateFormat<C> extends Format<C> {
/**
* Inject a trace state into a carrier.
*
* @param state trace state
* @param context span context
* @param carrier carrier to inject into
*/
void inject(InMemoryTraceState state, C carrier);
void inject(InMemorySpanContext context, C carrier);

/**
* Extract the trace state from a carrier.
Expand Down

0 comments on commit 619c76a

Please sign in to comment.