Skip to content

Commit

Permalink
Add grpc.server.method to WAF addresses with FQN of the grpc method
Browse files Browse the repository at this point in the history
  • Loading branch information
manuel-alvarez-alvarez committed May 28, 2024
1 parent bd6b34d commit a68ee7e
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ public interface KnownAddresses {
Address<CaseInsensitiveMap<List<String>>> HEADERS_NO_COOKIES =
new Address<>("server.request.headers.no_cookies");

Address<Object> GRPC_SERVER_METHOD = new Address<>("grpc.server.method");

Address<Object> GRPC_SERVER_REQUEST_MESSAGE = new Address<>("grpc.server.request.message");

// XXX: Not really used yet, but it's a known address and we should not treat it as unknown.
Expand Down Expand Up @@ -153,6 +155,8 @@ static Address<?> forName(String name) {
return REQUEST_QUERY;
case "server.request.headers.no_cookies":
return HEADERS_NO_COOKIES;
case "grpc.server.method":
return GRPC_SERVER_METHOD;
case "grpc.server.request.message":
return GRPC_SERVER_REQUEST_MESSAGE;
case "grpc.server.request.metadata":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public class GatewayBridge {
private volatile DataSubscriberInfo requestBodySubInfo;
private volatile DataSubscriberInfo pathParamsSubInfo;
private volatile DataSubscriberInfo respDataSubInfo;
private volatile DataSubscriberInfo grpcServerMethodSubInfo;
private volatile DataSubscriberInfo grpcServerRequestMsgSubInfo;
private volatile DataSubscriberInfo graphqlServerRequestMsgSubInfo;
private volatile DataSubscriberInfo requestEndSubInfo;
Expand Down Expand Up @@ -359,6 +360,32 @@ public void init() {
return maybePublishResponseData(ctx);
});

subscriptionService.registerCallback(
EVENTS.grpcServerMethod(),
(ctx_, method) -> {
AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC);
if (ctx == null) {
return NoopFlow.INSTANCE;
}
while (true) {
DataSubscriberInfo subInfo = grpcServerMethodSubInfo;
if (subInfo == null) {
subInfo = producerService.getDataSubscribers(KnownAddresses.GRPC_SERVER_METHOD);
grpcServerMethodSubInfo = subInfo;
}
if (subInfo == null || subInfo.isEmpty()) {
return NoopFlow.INSTANCE;
}
DataBundle bundle =
new SingletonDataBundle<>(KnownAddresses.GRPC_SERVER_METHOD, method);
try {
return producerService.publishDataEvent(subInfo, ctx, bundle, true);
} catch (ExpiredSubscriberInfoException e) {
grpcServerMethodSubInfo = null;
}
}
});

subscriptionService.registerCallback(
EVENTS.grpcServerRequestMessage(),
(ctx_, obj) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class GatewayBridgeSpecification extends DDSpecification {
BiFunction<RequestContext, Integer, Flow<Void>> responseStartedCB
TriConsumer<RequestContext, String, String> respHeaderCB
Function<RequestContext, Flow<Void>> respHeadersDoneCB
BiFunction<RequestContext, String, Flow<Void>> grpcServerMethodCB
BiFunction<RequestContext, Object, Flow<Void>> grpcServerRequestMessageCB
BiFunction<RequestContext, Map<String, Object>, Flow<Void>> graphqlServerRequestMessageCB

Expand Down Expand Up @@ -410,6 +411,7 @@ class GatewayBridgeSpecification extends DDSpecification {
1 * ig.registerCallback(EVENTS.responseStarted(), _) >> { responseStartedCB = it[1]; null }
1 * ig.registerCallback(EVENTS.responseHeader(), _) >> { respHeaderCB = it[1]; null }
1 * ig.registerCallback(EVENTS.responseHeaderDone(), _) >> { respHeadersDoneCB = it[1]; null }
1 * ig.registerCallback(EVENTS.grpcServerMethod(), _) >> { grpcServerMethodCB = it[1]; null }
1 * ig.registerCallback(EVENTS.grpcServerRequestMessage(), _) >> { grpcServerRequestMessageCB = it[1]; null }
1 * ig.registerCallback(EVENTS.graphqlServerRequestMessage(), _) >> { graphqlServerRequestMessageCB = it[1]; null }
0 * ig.registerCallback(_, _)
Expand Down Expand Up @@ -705,6 +707,22 @@ class GatewayBridgeSpecification extends DDSpecification {
flow.action == Flow.Action.Noop.INSTANCE
}

void 'grpc server method publishes'() {
setup:
eventDispatcher.getDataSubscribers(KnownAddresses.GRPC_SERVER_METHOD) >> nonEmptyDsInfo
DataBundle bundle

when:
Flow<?> flow = grpcServerMethodCB.apply(ctx, '/my.package.Greeter/SayHello')

then:
1 * eventDispatcher.publishDataEvent(nonEmptyDsInfo, ctx.data, _ as DataBundle, true) >>
{ args -> bundle = args[2]; NoopFlow.INSTANCE }
bundle.get(KnownAddresses.GRPC_SERVER_METHOD) == '/my.package.Greeter/SayHello'
flow.result == null
flow.action == Flow.Action.Noop.INSTANCE
}

void 'calls trace segment post processor'() {
setup:
AgentSpan span = Stub()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.grpc.ForwardingServerCallListener;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
Expand Down Expand Up @@ -76,6 +77,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
if (reqContext != null) {
callIGCallbackClientAddress(cbp, reqContext, call);
callIGCallbackHeaders(cbp, reqContext, headers);
callIGCallbackGrpcServerMethod(cbp, reqContext, call.getMethodDescriptor());
}

DECORATE.afterStart(span);
Expand Down Expand Up @@ -315,6 +317,16 @@ private static void callIGCallbackRequestEnded(@Nonnull final AgentSpan span) {
}
}

private static <ReqT, RespT> void callIGCallbackGrpcServerMethod(
CallbackProvider cbp, RequestContext ctx, MethodDescriptor<ReqT, RespT> methodDescriptor) {
String method = methodDescriptor.getFullMethodName();
BiFunction<RequestContext, String, Flow<Void>> cb = cbp.getCallback(EVENTS.grpcServerMethod());
if (method == null || cb == null) {
return;
}
cb.apply(ctx, method);
}

private static void callIGCallbackGrpcMessage(@Nonnull final AgentSpan span, Object obj) {
if (obj == null) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ abstract class ArmeriaGrpcTest extends VersionedNamingTestBase {

def collectedAppSecHeaders = [:]
boolean appSecHeaderDone = false
def collectedAppSecServerMethods = []
def collectedAppSecReqMsgs = []

final Duration timeoutDuration() {
Expand Down Expand Up @@ -97,6 +98,10 @@ abstract class ArmeriaGrpcTest extends VersionedNamingTestBase {
collectedAppSecReqMsgs << obj
Flow.ResultFlow.empty()
} as BiFunction<RequestContext, Object, Flow<Void>>)
ig.registerCallback(EVENTS.grpcServerMethod(), { reqCtx, method ->
collectedAppSecServerMethods << method
Flow.ResultFlow.empty()
} as BiFunction<RequestContext, String, Flow<Void>>)
}

def cleanup() {
Expand Down Expand Up @@ -230,6 +235,8 @@ abstract class ArmeriaGrpcTest extends VersionedNamingTestBase {
traceId.toLong() as String == collectedAppSecHeaders['x-datadog-trace-id']
collectedAppSecReqMsgs.size() == 1
collectedAppSecReqMsgs.first().name == name
collectedAppSecServerMethods.size() == 1
collectedAppSecServerMethods.first() == 'example.Greeter/SayHello'

and:
if (isDataStreamsEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.grpc.ForwardingServerCallListener;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
Expand Down Expand Up @@ -75,6 +76,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
if (reqContext != null) {
callIGCallbackClientAddress(cbp, reqContext, call);
callIGCallbackHeaders(cbp, reqContext, headers);
callIGCallbackGrpcServerMethod(cbp, reqContext, call.getMethodDescriptor());
}

DECORATE.afterStart(span);
Expand Down Expand Up @@ -314,6 +316,16 @@ private static void callIGCallbackRequestEnded(@Nonnull final AgentSpan span) {
}
}

private static <ReqT, RespT> void callIGCallbackGrpcServerMethod(
CallbackProvider cbp, RequestContext ctx, MethodDescriptor<ReqT, RespT> methodDescriptor) {
String method = methodDescriptor.getFullMethodName();
BiFunction<RequestContext, String, Flow<Void>> cb = cbp.getCallback(EVENTS.grpcServerMethod());
if (method == null || cb == null) {
return;
}
cb.apply(ctx, method);
}

private static void callIGCallbackGrpcMessage(@Nonnull final AgentSpan span, Object obj) {
if (obj == null) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ abstract class GrpcTest extends VersionedNamingTestBase {
def collectedAppSecHeaders = [:]
boolean appSecHeaderDone = false
def collectedAppSecReqMsgs = []
def collectedAppSecServerMethods = []

@Override
final String service() {
Expand Down Expand Up @@ -89,6 +90,10 @@ abstract class GrpcTest extends VersionedNamingTestBase {
collectedAppSecReqMsgs << obj
Flow.ResultFlow.empty()
} as BiFunction<RequestContext, Object, Flow<Void>>)
ig.registerCallback(EVENTS.grpcServerMethod(), { reqCtx, method ->
collectedAppSecServerMethods << method
Flow.ResultFlow.empty()
} as BiFunction<RequestContext, String, Flow<Void>>)
}

def cleanup() {
Expand Down Expand Up @@ -217,6 +222,8 @@ abstract class GrpcTest extends VersionedNamingTestBase {
traceId.toLong() as String == collectedAppSecHeaders['x-datadog-trace-id']
collectedAppSecReqMsgs.size() == 1
collectedAppSecReqMsgs.first().name == name
collectedAppSecServerMethods.size() == 1
collectedAppSecServerMethods.first() == 'example.Greeter/SayHello'

and:
if (isDataStreamsEnabled()) {
Expand Down
11 changes: 11 additions & 0 deletions internal-api/src/main/java/datadog/trace/api/gateway/Events.java
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,17 @@ public EventType<BiFunction<RequestContext, Object, Flow<Void>>> grpcServerReque
GRAPHQL_SERVER_REQUEST_MESSAGE;
}

static final int GRPC_SERVER_METHOD_ID = 16;

@SuppressWarnings("rawtypes")
private static final EventType GRPC_SERVER_METHOD =
new ET<>("grpc.server.method", GRPC_SERVER_METHOD_ID);

@SuppressWarnings("unchecked")
public EventType<BiFunction<RequestContext, String, Flow<Void>>> grpcServerMethod() {
return (EventType<BiFunction<RequestContext, String, Flow<Void>>>) GRPC_SERVER_METHOD;
}

static final int MAX_EVENTS = nextId.get();

private static final class ET<T> extends EventType<T> {
Expand Down

0 comments on commit a68ee7e

Please sign in to comment.