diff --git a/contrib/google.golang.org/grpc/client.go b/contrib/google.golang.org/grpc/client.go index 78046f10d9..ebaa2907ee 100644 --- a/contrib/google.golang.org/grpc/client.go +++ b/contrib/google.golang.org/grpc/client.go @@ -45,7 +45,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { if p, ok := peer.FromContext(cs.Context()); ok { setSpanTargetFromPeer(span, *p) } - defer func() { finishWithError(span, err, cs.cfg) }() + defer func() { finishWithError(span, err, cs.method, cs.cfg) }() } err = cs.ClientStream.RecvMsg(m) return err @@ -64,7 +64,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { if p, ok := peer.FromContext(cs.Context()); ok { setSpanTargetFromPeer(span, *p) } - defer func() { finishWithError(span, err, cs.cfg) }() + defer func() { finishWithError(span, err, cs.method, cs.cfg) }() } err = cs.ClientStream.SendMsg(m) return err @@ -104,7 +104,7 @@ func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor { return err }) if err != nil { - finishWithError(span, err, cfg) + finishWithError(span, err, method, cfg) return nil, err } @@ -116,7 +116,7 @@ func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor { go func() { <-stream.Context().Done() - finishWithError(span, stream.Context().Err(), cfg) + finishWithError(span, stream.Context().Err(), method, cfg) }() } else { // if call tracing is disabled, just call streamer, but still return @@ -158,7 +158,7 @@ func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { func(ctx context.Context, opts []grpc.CallOption) error { return invoker(ctx, method, req, reply, cc, opts...) }) - finishWithError(span, err, cfg) + finishWithError(span, err, method, cfg) return err } } diff --git a/contrib/google.golang.org/grpc/grpc.go b/contrib/google.golang.org/grpc/grpc.go index 323edbb542..f6da527ece 100644 --- a/contrib/google.golang.org/grpc/grpc.go +++ b/contrib/google.golang.org/grpc/grpc.go @@ -38,6 +38,8 @@ func init() { // cache a constant option: saves one allocation per call var spanTypeRPC = tracer.SpanType(ext.AppTypeRPC) +type fullMethodNameKey struct{} + func (cfg *config) startSpanOptions(opts ...tracer.StartSpanOption) []tracer.StartSpanOption { if len(cfg.tags) == 0 && len(cfg.spanOpts) == 0 { return opts @@ -73,16 +75,17 @@ func startSpanFromContext( if sctx, err := tracer.Extract(grpcutil.MDCarrier(md)); err == nil { opts = append(opts, tracer.ChildOf(sctx)) } + ctx = context.WithValue(ctx, fullMethodNameKey{}, method) return tracer.StartSpanFromContext(ctx, operation, opts...) } // finishWithError applies finish option and a tag with gRPC status code, disregarding OK, EOF and Canceled errors. -func finishWithError(span ddtrace.Span, err error, cfg *config) { +func finishWithError(span ddtrace.Span, err error, method string, cfg *config) { if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { err = nil } errcode := status.Code(err) - if errcode == codes.OK || cfg.nonErrorCodes[errcode] { + if errcode == codes.OK || cfg.nonErrorCodes[errcode] || (cfg.errCheck != nil && cfg.errCheck(method, err)) { err = nil } span.SetTag(tagCode, errcode.String()) diff --git a/contrib/google.golang.org/grpc/grpc_test.go b/contrib/google.golang.org/grpc/grpc_test.go index 9d3313b8ee..fcf97e74d5 100644 --- a/contrib/google.golang.org/grpc/grpc_test.go +++ b/contrib/google.golang.org/grpc/grpc_test.go @@ -663,6 +663,205 @@ func waitForSpans(mt mocktracer.Tracer, sz int) { } } +func TestWithErrorCheck(t *testing.T) { + t.Run("unary", func(t *testing.T) { + for name, tt := range map[string]struct { + errCheck func(method string, err error) bool + message string + withError bool + wantCode string + wantMessage string + }{ + "Invalid_with_no_error": { + message: "invalid", + errCheck: func(method string, err error) bool { + if err == nil { + return true + } + + errCode := status.Code(err) + if errCode == codes.InvalidArgument && method == "/grpc.Fixture/Ping" { + return true + } + + return false + }, + withError: false, + wantCode: codes.InvalidArgument.String(), + wantMessage: "invalid", + }, + "Invalid_with_error": { + message: "invalid", + errCheck: func(method string, err error) bool { + if err == nil { + return true + } + + errCode := status.Code(err) + if errCode == codes.InvalidArgument && method == "/some/endpoint" { + return true + } + + return false + }, + withError: true, + wantCode: codes.InvalidArgument.String(), + wantMessage: "invalid", + }, + "Invalid_with_error_without_errCheck": { + message: "invalid", + errCheck: nil, + withError: true, + wantCode: codes.InvalidArgument.String(), + wantMessage: "invalid", + }, + } { + t.Run(name, func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + var ops []Option + if tt.errCheck != nil { + ops = append(ops, WithErrorCheck(tt.errCheck)) + } + rig, err := newRig(true, ops...) + if err != nil { + t.Fatalf("error setting up rig: %s", err) + } + + client := rig.client + _, err = client.Ping(context.Background(), &FixtureRequest{Name: tt.message}) + assert.Error(t, err) + assert.Equal(t, tt.wantCode, status.Code(err).String()) + assert.Equal(t, tt.wantMessage, status.Convert(err).Message()) + + spans := mt.FinishedSpans() + assert.Len(t, spans, 2) + + var serverSpan, clientSpan mocktracer.Span + + for _, s := range spans { + // order of traces in buffer is not garanteed + switch s.OperationName() { + case "grpc.server": + serverSpan = s + case "grpc.client": + clientSpan = s + } + } + + if tt.withError { + assert.NotNil(t, clientSpan.Tag(ext.Error)) + assert.NotNil(t, serverSpan.Tag(ext.Error)) + } else { + assert.Nil(t, clientSpan.Tag(ext.Error)) + assert.Nil(t, serverSpan.Tag(ext.Error)) + } + + rig.Close() + mt.Reset() + }) + } + }) + + t.Run("stream", func(t *testing.T) { + for name, tt := range map[string]struct { + errCheck func(method string, err error) bool + message string + withError bool + wantCode string + wantMessage string + }{ + "Invalid_with_no_error": { + message: "invalid", + errCheck: func(method string, err error) bool { + if err == nil { + return true + } + + errCode := status.Code(err) + if errCode == codes.InvalidArgument && method == "/grpc.Fixture/StreamPing" { + return true + } + + return false + }, + withError: false, + wantCode: codes.InvalidArgument.String(), + wantMessage: "invalid", + }, + "Invalid_with_error": { + message: "invalid", + errCheck: func(method string, err error) bool { + if err == nil { + return true + } + + errCode := status.Code(err) + if errCode == codes.InvalidArgument && method == "/some/endpoint" { + return true + } + + return false + }, + withError: true, + wantCode: codes.InvalidArgument.String(), + wantMessage: "invalid", + }, + "Invalid_with_error_without_errCheck": { + message: "invalid", + errCheck: nil, + withError: true, + wantCode: codes.InvalidArgument.String(), + wantMessage: "invalid", + }, + } { + t.Run(name, func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + var opts []Option + if tt.errCheck != nil { + opts = append(opts, WithErrorCheck(tt.errCheck)) + } + rig, err := newRig(true, opts...) + if err != nil { + t.Fatalf("error setting up rig: %s", err) + } + + ctx, done := context.WithCancel(context.Background()) + client := rig.client + stream, err := client.StreamPing(ctx) + assert.NoError(t, err) + + err = stream.Send(&FixtureRequest{Name: tt.message}) + assert.NoError(t, err) + + _, err = stream.Recv() + assert.Error(t, err) + assert.Equal(t, tt.wantCode, status.Code(err).String()) + assert.Equal(t, tt.wantMessage, status.Convert(err).Message()) + + assert.NoError(t, stream.CloseSend()) + done() // close stream from client side + rig.Close() + + waitForSpans(mt, 5) + + spans := mt.FinishedSpans() + assert.Len(t, spans, 5) + + for _, s := range spans { + if s.Tag(ext.Error) != nil && !tt.withError { + assert.FailNow(t, "expected no error tag on the span") + } + } + + mt.Reset() + }) + } + }) +} + func TestAnalyticsSettings(t *testing.T) { assertRate := func(t *testing.T, mt mocktracer.Tracer, rate interface{}, opts ...InterceptorOption) { rig, err := newRig(true, opts...) diff --git a/contrib/google.golang.org/grpc/option.go b/contrib/google.golang.org/grpc/option.go index 43b1ad781b..6ace965307 100644 --- a/contrib/google.golang.org/grpc/option.go +++ b/contrib/google.golang.org/grpc/option.go @@ -29,6 +29,7 @@ type config struct { serviceName func() string spanName string nonErrorCodes map[codes.Code]bool + errCheck func(method string, err error) bool traceStreamCalls bool traceStreamMessages bool noDebugStack bool @@ -126,6 +127,15 @@ func NonErrorCodes(cs ...codes.Code) InterceptorOption { } } +// WithErrorCheck sets a custom function to determine whether an error should not be considered as an error for tracing purposes. +// This function is evaluated when an error occurs, and if it returns true, the error will not be recorded in the trace. +// f: A function taking the gRPC method and error as arguments, returning a boolean to indicate if the error should be ignored. +func WithErrorCheck(f func(method string, err error) bool) Option { + return func(cfg *config) { + cfg.errCheck = f + } +} + // WithAnalytics enables Trace Analytics for all started spans. func WithAnalytics(on bool) Option { return func(cfg *config) { diff --git a/contrib/google.golang.org/grpc/server.go b/contrib/google.golang.org/grpc/server.go index a751e1a3a6..be53f7cc7a 100644 --- a/contrib/google.golang.org/grpc/server.go +++ b/contrib/google.golang.org/grpc/server.go @@ -54,7 +54,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { defer func() { withMetadataTags(ss.ctx, ss.cfg, span) withRequestTags(ss.cfg, m, span) - finishWithError(span, err, ss.cfg) + finishWithError(span, err, ss.method, ss.cfg) }() } err = ss.ServerStream.RecvMsg(m) @@ -73,7 +73,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { ss.cfg.startSpanOptions(tracer.Measured())..., ) span.SetTag(ext.Component, componentName) - defer func() { finishWithError(span, err, ss.cfg) }() + defer func() { finishWithError(span, err, ss.method, ss.cfg) }() } err = ss.ServerStream.SendMsg(m) return err @@ -111,7 +111,7 @@ func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { case info.IsClientStream: span.SetTag(tagMethodKind, methodKindClientStream) } - defer func() { finishWithError(span, err, cfg) }() + defer func() { finishWithError(span, err, info.FullMethod, cfg) }() if appsec.Enabled() { handler = appsecStreamHandlerMiddleware(info.FullMethod, span, handler) } @@ -158,7 +158,7 @@ func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { handler = appsecUnaryHandlerMiddleware(info.FullMethod, span, handler) } resp, err := handler(ctx, req) - finishWithError(span, err, cfg) + finishWithError(span, err, info.FullMethod, cfg) return resp, err } } diff --git a/contrib/google.golang.org/grpc/stats_client.go b/contrib/google.golang.org/grpc/stats_client.go index a2c0aa43da..b1af02789e 100644 --- a/contrib/google.golang.org/grpc/stats_client.go +++ b/contrib/google.golang.org/grpc/stats_client.go @@ -58,7 +58,8 @@ func (h *clientStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) { span.SetTag(ext.TargetPort, port) } case *stats.End: - finishWithError(span, rs.Error, h.cfg) + fullMethod, _ := ctx.Value(fullMethodNameKey{}).(string) + finishWithError(span, rs.Error, fullMethod, h.cfg) } } diff --git a/contrib/google.golang.org/grpc/stats_server.go b/contrib/google.golang.org/grpc/stats_server.go index 4da9316001..e979c31020 100644 --- a/contrib/google.golang.org/grpc/stats_server.go +++ b/contrib/google.golang.org/grpc/stats_server.go @@ -44,6 +44,7 @@ func (h *serverStatsHandler) TagRPC(ctx context.Context, rti *stats.RPCTagInfo) h.cfg.serviceName, spanOpts..., ) + ctx = context.WithValue(ctx, fullMethodNameKey{}, rti.FullMethodName) return ctx } @@ -53,8 +54,10 @@ func (h *serverStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) { if !ok { return } + + fullMethod, _ := ctx.Value(fullMethodNameKey{}).(string) if v, ok := rs.(*stats.End); ok { - finishWithError(span, v.Error, h.cfg) + finishWithError(span, v.Error, fullMethod, h.cfg) } }