diff --git a/contrib/99designs/gqlgen/tracer.go b/contrib/99designs/gqlgen/tracer.go index e48ba8856f..ae541a8343 100644 --- a/contrib/99designs/gqlgen/tracer.go +++ b/contrib/99designs/gqlgen/tracer.go @@ -50,6 +50,7 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/graphqlsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/graphqlsec/types" "gopkg.in/DataDog/dd-trace-go.v1/internal/namingschema" "gopkg.in/DataDog/dd-trace-go.v1/internal/telemetry" @@ -103,12 +104,12 @@ func (t *gqlTracer) Validate(_ graphql.ExecutableSchema) error { func (t *gqlTracer) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { opCtx := graphql.GetOperationContext(ctx) span, ctx := t.createRootSpan(ctx, opCtx) - ctx, req := graphqlsec.StartRequestOperation(ctx, nil /* root */, span, graphqlsec.RequestOperationArgs{ + ctx, req := graphqlsec.StartRequestOperation(ctx, nil /* root */, span, types.RequestOperationArgs{ RawQuery: opCtx.RawQuery, OperationName: opCtx.OperationName, Variables: opCtx.Variables, }) - ctx, query := graphqlsec.StartExecutionOperation(ctx, req, span, graphqlsec.ExecutionOperationArgs{ + ctx, query := graphqlsec.StartExecutionOperation(ctx, req, span, types.ExecutionOperationArgs{ Query: opCtx.RawQuery, OperationName: opCtx.OperationName, Variables: opCtx.Variables, @@ -123,11 +124,11 @@ func (t *gqlTracer) InterceptOperation(ctx context.Context, next graphql.Operati } defer span.Finish(tracer.WithError(err)) } - query.Finish(graphqlsec.ExecutionOperationRes{ + query.Finish(types.ExecutionOperationRes{ Data: response.Data, // NB - This is raw data, but rather not parse it (possibly expensive). Error: response.Errors, }) - req.Finish(graphqlsec.RequestOperationRes{ + req.Finish(types.RequestOperationRes{ Data: response.Data, // NB - This is raw data, but rather not parse it (possibly expensive). Error: response.Errors, }) @@ -150,13 +151,13 @@ func (t *gqlTracer) InterceptField(ctx context.Context, next graphql.Resolver) ( } span, ctx := tracer.StartSpanFromContext(ctx, fieldOp, opts...) defer func() { span.Finish(tracer.WithError(err)) }() - ctx, op := graphqlsec.StartResolveOperation(ctx, graphqlsec.FromContext[*graphqlsec.ExecutionOperation](ctx), span, graphqlsec.ResolveOperationArgs{ + ctx, op := graphqlsec.StartResolveOperation(ctx, graphqlsec.FromContext[*types.ExecutionOperation](ctx), span, types.ResolveOperationArgs{ Arguments: fieldCtx.Args, TypeName: fieldCtx.Object, FieldName: fieldCtx.Field.Name, Trivial: !(fieldCtx.IsMethod || fieldCtx.IsResolver), // TODO: Is this accurate? }) - defer func() { op.Finish(graphqlsec.ResolveOperationRes{Data: res, Error: err}) }() + defer func() { op.Finish(types.ResolveOperationRes{Data: res, Error: err}) }() res, err = next(ctx) return } diff --git a/contrib/google.golang.org/grpc/appsec.go b/contrib/google.golang.org/grpc/appsec.go index e1120bbfad..a8c9c110e5 100644 --- a/contrib/google.golang.org/grpc/appsec.go +++ b/contrib/google.golang.org/grpc/appsec.go @@ -11,6 +11,7 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/grpcsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/grpcsec/types" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/sharedsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/trace" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/trace/grpctrace" @@ -33,14 +34,16 @@ func appsecUnaryHandlerMiddleware(span ddtrace.Span, handler grpc.UnaryHandler) var blocked bool md, _ := metadata.FromIncomingContext(ctx) clientIP := setClientIP(ctx, span, md) - args := grpcsec.HandlerOperationArgs{Metadata: md, ClientIP: clientIP} - ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, dyngo.NewDataListener(func(a *sharedsec.Action) { - code, e := a.GRPC()(md) - blocked = a.Blocking() - err = status.Error(codes.Code(code), e.Error()) - })) + args := types.HandlerOperationArgs{Metadata: md, ClientIP: clientIP} + ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, func(op *types.HandlerOperation) { + dyngo.OnData(op, func(a *sharedsec.Action) { + code, e := a.GRPC()(md) + blocked = a.Blocking() + err = status.Error(codes.Code(code), e.Error()) + }) + }) defer func() { - events := op.Finish(grpcsec.HandlerOperationRes{}) + events := op.Finish(types.HandlerOperationRes{}) if blocked { op.SetTag(trace.BlockedRequestTag, true) } @@ -54,9 +57,9 @@ func appsecUnaryHandlerMiddleware(span ddtrace.Span, handler grpc.UnaryHandler) if err != nil { return nil, err } - defer grpcsec.StartReceiveOperation(grpcsec.ReceiveOperationArgs{}, op).Finish(grpcsec.ReceiveOperationRes{Message: req}) + defer grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, op).Finish(types.ReceiveOperationRes{Message: req}) rv, err := handler(ctx, req) - if e, ok := err.(*grpcsec.MonitoringError); ok { + if e, ok := err.(*types.MonitoringError); ok { err = status.Error(codes.Code(e.GRPCStatus()), e.Error()) } return rv, err @@ -74,18 +77,20 @@ func appsecStreamHandlerMiddleware(span ddtrace.Span, handler grpc.StreamHandler clientIP := setClientIP(ctx, span, md) grpctrace.SetRequestMetadataTags(span, md) - ctx, op := grpcsec.StartHandlerOperation(ctx, grpcsec.HandlerOperationArgs{Metadata: md, ClientIP: clientIP}, nil, dyngo.NewDataListener(func(a *sharedsec.Action) { - code, e := a.GRPC()(md) - blocked = a.Blocking() - err = status.Error(codes.Code(code), e.Error()) - })) + ctx, op := grpcsec.StartHandlerOperation(ctx, types.HandlerOperationArgs{Metadata: md, ClientIP: clientIP}, nil, func(op *types.HandlerOperation) { + dyngo.OnData(op, func(a *sharedsec.Action) { + code, e := a.GRPC()(md) + blocked = a.Blocking() + err = status.Error(codes.Code(code), e.Error()) + }) + }) stream = appsecServerStream{ ServerStream: stream, handlerOperation: op, ctx: ctx, } defer func() { - events := op.Finish(grpcsec.HandlerOperationRes{}) + events := op.Finish(types.HandlerOperationRes{}) if blocked { op.SetTag(trace.BlockedRequestTag, true) } @@ -100,7 +105,7 @@ func appsecStreamHandlerMiddleware(span ddtrace.Span, handler grpc.StreamHandler } err = handler(srv, stream) - if e, ok := err.(*grpcsec.MonitoringError); ok { + if e, ok := err.(*types.MonitoringError); ok { err = status.Error(codes.Code(e.GRPCStatus()), e.Error()) } return err @@ -109,16 +114,16 @@ func appsecStreamHandlerMiddleware(span ddtrace.Span, handler grpc.StreamHandler type appsecServerStream struct { grpc.ServerStream - handlerOperation *grpcsec.HandlerOperation + handlerOperation *types.HandlerOperation ctx context.Context } // RecvMsg implements grpc.ServerStream interface method to monitor its // execution with AppSec. func (ss appsecServerStream) RecvMsg(m interface{}) error { - op := grpcsec.StartReceiveOperation(grpcsec.ReceiveOperationArgs{}, ss.handlerOperation) + op := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, ss.handlerOperation) defer func() { - op.Finish(grpcsec.ReceiveOperationRes{Message: m}) + op.Finish(types.ReceiveOperationRes{Message: m}) }() return ss.ServerStream.RecvMsg(m) } diff --git a/contrib/graph-gophers/graphql-go/graphql.go b/contrib/graph-gophers/graphql-go/graphql.go index 5d14cc514b..08b82ac4e8 100644 --- a/contrib/graph-gophers/graphql-go/graphql.go +++ b/contrib/graph-gophers/graphql-go/graphql.go @@ -20,6 +20,7 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" ddtracer "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/graphqlsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/graphqlsec/types" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" "gopkg.in/DataDog/dd-trace-go.v1/internal/telemetry" @@ -70,12 +71,12 @@ func (t *Tracer) TraceQuery(ctx context.Context, queryString, operationName stri } span, ctx := ddtracer.StartSpanFromContext(ctx, t.cfg.querySpanName, opts...) - ctx, request := graphqlsec.StartRequestOperation(ctx, nil, span, graphqlsec.RequestOperationArgs{ + ctx, request := graphqlsec.StartRequestOperation(ctx, nil, span, types.RequestOperationArgs{ RawQuery: queryString, OperationName: operationName, Variables: variables, }) - ctx, query := graphqlsec.StartExecutionOperation(ctx, request, span, graphqlsec.ExecutionOperationArgs{ + ctx, query := graphqlsec.StartExecutionOperation(ctx, request, span, types.ExecutionOperationArgs{ Query: queryString, OperationName: operationName, Variables: variables, @@ -92,8 +93,8 @@ func (t *Tracer) TraceQuery(ctx context.Context, queryString, operationName stri err = fmt.Errorf("%s (and %d more errors)", errs[0], n-1) } defer span.Finish(ddtracer.WithError(err)) - defer request.Finish(graphqlsec.RequestOperationRes{Error: err}) - query.Finish(graphqlsec.ExecutionOperationRes{Error: err}) + defer request.Finish(types.RequestOperationRes{Error: err}) + query.Finish(types.ExecutionOperationRes{Error: err}) } } @@ -119,7 +120,7 @@ func (t *Tracer) TraceField(ctx context.Context, _, typeName, fieldName string, } span, ctx := ddtracer.StartSpanFromContext(ctx, "graphql.field", opts...) - ctx, field := graphqlsec.StartResolveOperation(ctx, graphqlsec.FromContext[*graphqlsec.ExecutionOperation](ctx), span, graphqlsec.ResolveOperationArgs{ + ctx, field := graphqlsec.StartResolveOperation(ctx, graphqlsec.FromContext[*types.ExecutionOperation](ctx), span, types.ResolveOperationArgs{ TypeName: typeName, FieldName: fieldName, Arguments: arguments, @@ -127,7 +128,7 @@ func (t *Tracer) TraceField(ctx context.Context, _, typeName, fieldName string, }) return ctx, func(err *errors.QueryError) { - field.Finish(graphqlsec.ResolveOperationRes{Error: err}) + field.Finish(types.ResolveOperationRes{Error: err}) // must explicitly check for nil, see issue golang/go#22729 if err != nil { diff --git a/contrib/graphql-go/graphql/graphql.go b/contrib/graphql-go/graphql/graphql.go index 79a5dd5139..1db6522435 100644 --- a/contrib/graphql-go/graphql/graphql.go +++ b/contrib/graphql-go/graphql/graphql.go @@ -15,6 +15,7 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/graphqlsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/graphqlsec/types" "gopkg.in/DataDog/dd-trace-go.v1/internal/telemetry" "github.com/graphql-go/graphql" @@ -63,7 +64,7 @@ type datadogExtension struct{ config } type contextKey struct{} type contextData struct { serverSpan tracer.Span - requestOp *graphqlsec.RequestOperation + requestOp *types.RequestOperation variables map[string]any query string operationName string @@ -72,7 +73,7 @@ type contextData struct { // finish closes the top-level request operation, as well as the server span. func (c *contextData) finish(data any, err error) { defer c.serverSpan.Finish(tracer.WithError(err)) - c.requestOp.Finish(graphqlsec.RequestOperationRes{Data: data, Error: err}) + c.requestOp.Finish(types.RequestOperationRes{Data: data, Error: err}) } var extensionName = reflect.TypeOf((*datadogExtension)(nil)).Elem().Name() @@ -97,7 +98,7 @@ func (i datadogExtension) Init(ctx context.Context, params *graphql.Params) cont tracer.Tag(ext.Component, componentName), tracer.Measured(), ) - ctx, request := graphqlsec.StartRequestOperation(ctx, nil, span, graphqlsec.RequestOperationArgs{ + ctx, request := graphqlsec.StartRequestOperation(ctx, nil, span, types.RequestOperationArgs{ RawQuery: params.RequestString, Variables: params.VariableValues, OperationName: params.OperationName, @@ -192,7 +193,7 @@ func (i datadogExtension) ExecutionDidStart(ctx context.Context) (context.Contex opts = append(opts, tracer.Tag(ext.EventSampleRate, i.config.analyticsRate)) } span, ctx := tracer.StartSpanFromContext(ctx, spanExecute, opts...) - ctx, op := graphqlsec.StartExecutionOperation(ctx, graphqlsec.FromContext[*graphqlsec.RequestOperation](ctx), span, graphqlsec.ExecutionOperationArgs{ + ctx, op := graphqlsec.StartExecutionOperation(ctx, graphqlsec.FromContext[*types.RequestOperation](ctx), span, types.ExecutionOperationArgs{ Query: data.query, OperationName: data.operationName, Variables: data.variables, @@ -203,7 +204,7 @@ func (i datadogExtension) ExecutionDidStart(ctx context.Context) (context.Contex defer data.finish(result.Data, err) span.Finish(tracer.WithError(err)) }() - op.Finish(graphqlsec.ExecutionOperationRes{Data: result.Data, Error: err}) + op.Finish(types.ExecutionOperationRes{Data: result.Data, Error: err}) } } @@ -239,14 +240,14 @@ func (i datadogExtension) ResolveFieldDidStart(ctx context.Context, info *graphq opts = append(opts, tracer.Tag(ext.EventSampleRate, i.config.analyticsRate)) } span, ctx := tracer.StartSpanFromContext(ctx, spanResolve, opts...) - ctx, op := graphqlsec.StartResolveOperation(ctx, graphqlsec.FromContext[*graphqlsec.ExecutionOperation](ctx), span, graphqlsec.ResolveOperationArgs{ + ctx, op := graphqlsec.StartResolveOperation(ctx, graphqlsec.FromContext[*types.ExecutionOperation](ctx), span, types.ResolveOperationArgs{ TypeName: info.ParentType.Name(), FieldName: info.FieldName, Arguments: collectArguments(info), }) return ctx, func(result any, err error) { defer span.Finish(tracer.WithError(err)) - op.Finish(graphqlsec.ResolveOperationRes{Error: err, Data: result}) + op.Finish(types.ResolveOperationRes{Error: err, Data: result}) } } diff --git a/contrib/labstack/echo.v4/appsec.go b/contrib/labstack/echo.v4/appsec.go index c278e2845c..a241c40d37 100644 --- a/contrib/labstack/echo.v4/appsec.go +++ b/contrib/labstack/echo.v4/appsec.go @@ -10,6 +10,7 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/httpsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/httpsec/types" "github.com/labstack/echo/v4" ) @@ -26,7 +27,7 @@ func withAppSec(next echo.HandlerFunc, span tracer.Span) echo.HandlerFunc { err = next(c) // If the error is a monitoring one, it means appsec actions will take care of writing the response // and handling the error. Don't call the echo error handler in this case - if _, ok := err.(*httpsec.MonitoringError); !ok && err != nil { + if _, ok := err.(*types.MonitoringError); !ok && err != nil { c.Error(err) } }) diff --git a/internal/appsec/dyngo/operation.go b/internal/appsec/dyngo/operation.go index d45492d22f..e16d357e9d 100644 --- a/internal/appsec/dyngo/operation.go +++ b/internal/appsec/dyngo/operation.go @@ -21,7 +21,6 @@ package dyngo import ( - "reflect" "sync" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" @@ -34,50 +33,33 @@ import ( // operation once it finishes so that it no longer can be called on finished // operations. type Operation interface { - // On allows to register an event listener to the operation. The event - // listener will be removed from the operation once it finishes. - On(EventListener) - - // OnData allows to register a data listener to the operation - OnData(DataListener) - - // EmitData sends data to the data listeners of the operation - EmitData(any) - - // Parent return the parent operation. It returns nil for the root - // operation. + // Parent returns the parent operation, or nil for the root operation. Parent() Operation - // emitEvent emits the event to listeners of the given argsType and calls - // them with the given op and v values. - // emitEvent is a private method implemented by the operation struct type so - // that no other package can define it. - emitEvent(argsType reflect.Type, op Operation, v interface{}) - - emitData(argsType reflect.Type, v any) + // unwrap is an internal method guaranteeing only *operation implements Operation. + unwrap() *operation +} - // add the given event listeners to the operation. - // add is a private method implemented by the operation struct type so - // that no other package can define it. - add(...EventListener) +// ArgOf marks a particular type as being the argument type of a given operation +// type. This allows this type to be listened to by an operation start listener. +// This removes the possibility of incorrectly pairing an operation and payload +// when setting up listeners, as it allows compiler-assisted coherence checks. +type ArgOf[O Operation] interface { + IsArgOf(O) +} - // finish the operation. This method allows to pass the operation value to - // use to emit the finish event. - // finish is a private method implemented by the operation struct type so - // that no other package can define it. - finish(op Operation, results interface{}) +// ResultOf marks a particular type as being the result type of a given +// operation. This allows this type to be listened to by an operation finish +// listener. +// This removes the possibility of incorrectly pairing an operation and payload +// when setting up listeners, as it allows compiler-assisted coherence checks. +type ResultOf[O Operation] interface { + IsResultOf(O) } // EventListener interface allowing to identify the Go type listened to and // dispatch calls to the underlying event listener function. -type EventListener interface { - // ListenedType returns the Go type the event listener listens to. - ListenedType() reflect.Type - // Call the underlying event listener function. The type of the value v - // is the type the event listener listens to, according to the type - // returned by ListenedType(). - Call(op Operation, v interface{}) -} +type EventListener[O Operation, T any] func(O, T) // Atomic *Operation so we can atomically read or swap it. var rootOperation atomic.Pointer[Operation] @@ -87,9 +69,9 @@ var rootOperation atomic.Pointer[Operation] // existing and running operation are still valid. func SwapRootOperation(new Operation) { rootOperation.Swap(&new) - // Note: calling FinishOperation(old) could result into mem leaks because + // Note: calling Finish(old, ...) could result into mem leaks because // some finish event listeners, possibly releasing memory and resources, - // wouldn't be called anymore (because finish() disables the operation and + // wouldn't be called anymore (because Finish() disables the operation and // removes the event listeners). } @@ -98,7 +80,7 @@ func SwapRootOperation(new Operation) { // bubble-up the operation stack, which allows listening to future events that // might happen in the operation lifetime. type operation struct { - parent Operation + parent *operation eventRegister dataBroadcaster @@ -106,21 +88,27 @@ type operation struct { mu sync.RWMutex } +func (o *operation) Parent() Operation { + return o.parent +} + +// This is the one true Operation implementation! +func (o *operation) unwrap() *operation { return o } + // NewRootOperation creates and returns a new root operation, with no parent // operation. Root operations are meant to be the top-level operation of an // operation stack, therefore receiving all the operation events. It allows to // prepare a new set of event listeners, to then atomically swap it with the // current one. func NewRootOperation() Operation { - return newOperation(nil) + return &operation{parent: nil} } // NewOperation creates and returns a new operation. It must be started by calling -// StartOperation, and finished by calling FinishOperation. The returned -// operation should be used in wrapper types to provide statically typed start -// and finish functions. The following example shows how to wrap an operation -// so that its functions are statically typed (instead of dyngo's interface{} -// values): +// StartOperation, and finished by calling Finish. The returned operation should +// be used in wrapper types to provide statically typed start and finish +// functions. The following example shows how to wrap an operation so that its +// functions are statically typed (instead of dyngo's interface{} values): // // package mypackage // import "dyngo" @@ -141,51 +129,43 @@ func NewRootOperation() Operation { // } func NewOperation(parent Operation) Operation { if parent == nil { - if root := rootOperation.Load(); root != nil { - parent = *root + if ptr := rootOperation.Load(); ptr != nil { + parent = *ptr } } - return newOperation(parent) + var parentOp *operation + if parent != nil { + parentOp = parent.unwrap() + } + return &operation{parent: parentOp} } // StartOperation starts a new operation along with its arguments and emits a // start event with the operation arguments. -func StartOperation(op Operation, args interface{}) { - argsType := reflect.TypeOf(args) +func StartOperation[O Operation, E ArgOf[O]](op O, args E) { // Bubble-up the start event starting from the parent operation as you can't // listen for your own start event - for current := op.Parent(); current != nil; current = current.Parent() { - current.emitEvent(argsType, op, args) + for current := op.unwrap().parent; current != nil; current = current.parent { + emitEvent(¤t.eventRegister, op, args) } } -func newOperation(parent Operation) *operation { - return &operation{parent: parent} -} - -// Parent return the parent operation. It returns nil for the root operation. -func (o *operation) Parent() Operation { - return o.parent -} - // FinishOperation finishes the operation along with its results and emits a // finish event with the operation results. // The operation is then disabled and its event listeners removed. -func FinishOperation(op Operation, results interface{}) { - op.finish(op, results) -} +func FinishOperation[O Operation, E ResultOf[O]](op O, results E) { + o := op.unwrap() + defer o.disable() // This will need the RLock below to be released... -func (o *operation) finish(op Operation, results interface{}) { - // Defer the call to o.disable() first so that the RWMutex gets unlocked first - defer o.disable() o.mu.RLock() defer o.mu.RUnlock() // Deferred and stacked on top of the previously deferred call to o.disable() + if o.disabled { return } - resType := reflect.TypeOf(results) - for current := op; current != nil; current = current.Parent() { - current.emitEvent(resType, op, results) + + for current := o; current != nil; current = current.parent { + emitEvent(¤t.eventRegister, op, results) } } @@ -193,55 +173,58 @@ func (o *operation) finish(op Operation, results interface{}) { func (o *operation) disable() { o.mu.Lock() defer o.mu.Unlock() + if o.disabled { return } + o.disabled = true o.eventRegister.clear() } -// Add the given event listeners to the operation. -func (o *operation) add(l ...EventListener) { +// On registers and event listener that will be called when the operation +// begins. +func On[O Operation, E ArgOf[O]](op Operation, l EventListener[O, E]) { + o := op.unwrap() + o.mu.RLock() defer o.mu.RUnlock() if o.disabled { return } - for _, l := range l { - if l == nil { - continue - } - key := l.ListenedType() - o.eventRegister.add(key, l) - } + addEventListener(&o.eventRegister, l) } -// On registers the event listener. The difference with the Register() is that -// it doesn't return a function closure, which avoids unnecessary allocations -// For example: -// -// op.On(MyOperationStart(func (op MyOperation, args MyOperationArgs) { -// // ... -// })) -func (o *operation) On(l EventListener) { +// OnFinish registers an event listener that will be called when the operation +// finishes. +func OnFinish[O Operation, E ResultOf[O]](op Operation, l EventListener[O, E]) { + o := op.unwrap() + o.mu.RLock() defer o.mu.RUnlock() if o.disabled { return } - o.eventRegister.add(l.ListenedType(), l) + addEventListener(&o.eventRegister, l) } -func (o *operation) OnData(l DataListener) { +func OnData[T any](op Operation, l DataListener[T]) { + o := op.unwrap() + o.mu.RLock() defer o.mu.RUnlock() if o.disabled { return } - o.dataBroadcaster.add(l.ListenedType(), l) + addDataListener(&o.dataBroadcaster, l) } -func (o *operation) EmitData(data any) { +// EmitData sends a data event up the operation stack. Listeners will be matched +// based on `T`. Callers may need to manually specify T when the static type of +// the value is more specific that the intended data event type. +func EmitData[T any](op Operation, data T) { + o := op.unwrap() + o.mu.RLock() defer o.mu.RUnlock() if o.disabled { @@ -250,55 +233,43 @@ func (o *operation) EmitData(data any) { // Bubble up the data to the stack of operations. Contrary to events, // we also send the data to ourselves since SDK operations are leaf operations // that both emit and listen for data (errors). - for current := Operation(o); current != nil; current = current.Parent() { - current.emitData(reflect.TypeOf(data), data) + for current := o; current != nil; current = current.parent { + emitData(¤t.dataBroadcaster, data) } } type ( // eventRegister implements a thread-safe list of event listeners. eventRegister struct { - mu sync.RWMutex listeners eventListenerMap + mu sync.RWMutex } // eventListenerMap is the map of event listeners. The list of listeners are // indexed by the operation argument or result type the event listener // expects. - eventListenerMap map[reflect.Type][]EventListener + eventListenerMap map[any][]any + + typeID[T any] struct{} dataBroadcaster struct { - mu sync.RWMutex listeners dataListenerMap + mu sync.RWMutex } - dataListenerSpec[T any] func(data T) - DataListener EventListener - dataListenerMap map[reflect.Type][]DataListener + DataListener[T any] func(T) + dataListenerMap map[any][]any ) -func (l dataListenerSpec[T]) Call(_ Operation, v interface{}) { - l(v.(T)) -} - -func (l dataListenerSpec[T]) ListenedType() reflect.Type { - return reflect.TypeOf((*T)(nil)).Elem() -} - -// NewDataListener creates a specialized generic data listener, wrapped under a DataListener interface -func NewDataListener[T any](f func(data T)) DataListener { - return dataListenerSpec[T](f) -} - -func (b *dataBroadcaster) add(key reflect.Type, l DataListener) { +func addDataListener[T any](b *dataBroadcaster, l DataListener[T]) { b.mu.Lock() defer b.mu.Unlock() if b.listeners == nil { b.listeners = make(dataListenerMap) } + key := typeID[DataListener[T]]{} b.listeners[key] = append(b.listeners[key], l) - } func (b *dataBroadcaster) clear() { @@ -307,7 +278,7 @@ func (b *dataBroadcaster) clear() { b.listeners = nil } -func (b *dataBroadcaster) emitData(key reflect.Type, v any) { +func emitData[T any](b *dataBroadcaster, v T) { defer func() { if r := recover(); r != nil { log.Error("appsec: recovered from an unexpected panic from an event listener: %+v", r) @@ -315,21 +286,20 @@ func (b *dataBroadcaster) emitData(key reflect.Type, v any) { }() b.mu.RLock() defer b.mu.RUnlock() - for t := range b.listeners { - if key == t || key.Implements(t) { - for _, listener := range b.listeners[t] { - listener.Call(nil, v) - } - } + + for _, listener := range b.listeners[typeID[DataListener[T]]{}] { + listener.(DataListener[T])(v) } } -func (r *eventRegister) add(key reflect.Type, l EventListener) { +func addEventListener[O Operation, T any](r *eventRegister, l EventListener[O, T]) { r.mu.Lock() defer r.mu.Unlock() + if r.listeners == nil { - r.listeners = make(eventListenerMap) + r.listeners = make(eventListenerMap, 2) } + key := typeID[EventListener[O, T]]{} r.listeners[key] = append(r.listeners[key], l) } @@ -339,7 +309,7 @@ func (r *eventRegister) clear() { r.listeners = nil } -func (r *eventRegister) emitEvent(key reflect.Type, op Operation, v interface{}) { +func emitEvent[O Operation, T any](r *eventRegister, op O, v T) { defer func() { if r := recover(); r != nil { log.Error("appsec: recovered from an unexpected panic from an event listener: %+v", r) @@ -347,7 +317,8 @@ func (r *eventRegister) emitEvent(key reflect.Type, op Operation, v interface{}) }() r.mu.RLock() defer r.mu.RUnlock() - for _, listener := range r.listeners[key] { - listener.Call(op, v) + + for _, listener := range r.listeners[typeID[EventListener[O, T]]{}] { + listener.(EventListener[O, T])(op, v) } } diff --git a/internal/appsec/dyngo/operation_test.go b/internal/appsec/dyngo/operation_test.go index f9dcfca7dc..08f1150c9b 100644 --- a/internal/appsec/dyngo/operation_test.go +++ b/internal/appsec/dyngo/operation_test.go @@ -19,7 +19,6 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -29,28 +28,19 @@ type ( RootRes struct{} ) +func (RootArgs) IsArgOf(operation) {} +func (RootRes) IsResultOf(operation) {} + type ( HTTPHandlerArgs struct { URL *url.URL Headers http.Header } - HTTPHandlerRes struct{} - OnHTTPHandlerOperationStart func(dyngo.Operation, HTTPHandlerArgs) - OnHTTPHandlerOperationFinish func(dyngo.Operation, HTTPHandlerRes) + HTTPHandlerRes struct{} ) -func (f OnHTTPHandlerOperationStart) ListenedType() reflect.Type { - return reflect.TypeOf((*HTTPHandlerArgs)(nil)).Elem() -} -func (f OnHTTPHandlerOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op, v.(HTTPHandlerArgs)) -} -func (f OnHTTPHandlerOperationFinish) ListenedType() reflect.Type { - return reflect.TypeOf((*HTTPHandlerRes)(nil)).Elem() -} -func (f OnHTTPHandlerOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op, v.(HTTPHandlerRes)) -} +func (HTTPHandlerArgs) IsArgOf(operation) {} +func (HTTPHandlerRes) IsResultOf(operation) {} type ( SQLQueryArgs struct { @@ -59,22 +49,10 @@ type ( SQLQueryRes struct { Err error } - OnSQLQueryOperationStart func(dyngo.Operation, SQLQueryArgs) - OnSQLQueryOperationFinish func(dyngo.Operation, SQLQueryRes) ) -func (f OnSQLQueryOperationStart) ListenedType() reflect.Type { - return reflect.TypeOf((*SQLQueryArgs)(nil)).Elem() -} -func (f OnSQLQueryOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op, v.(SQLQueryArgs)) -} -func (f OnSQLQueryOperationFinish) ListenedType() reflect.Type { - return reflect.TypeOf((*SQLQueryRes)(nil)).Elem() -} -func (f OnSQLQueryOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op, v.(SQLQueryRes)) -} +func (SQLQueryArgs) IsArgOf(operation) {} +func (SQLQueryRes) IsResultOf(operation) {} type ( GRPCHandlerArgs struct { @@ -83,22 +61,10 @@ type ( GRPCHandlerRes struct { Res interface{} } - OnGRPCHandlerOperationStart func(dyngo.Operation, GRPCHandlerArgs) - OnGRPCHandlerOperationFinish func(dyngo.Operation, GRPCHandlerRes) ) -func (f OnGRPCHandlerOperationStart) ListenedType() reflect.Type { - return reflect.TypeOf((*GRPCHandlerArgs)(nil)).Elem() -} -func (f OnGRPCHandlerOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op, v.(GRPCHandlerArgs)) -} -func (f OnGRPCHandlerOperationFinish) ListenedType() reflect.Type { - return reflect.TypeOf((*GRPCHandlerRes)(nil)).Elem() -} -func (f OnGRPCHandlerOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op, v.(GRPCHandlerRes)) -} +func (GRPCHandlerArgs) IsArgOf(operation) {} +func (GRPCHandlerRes) IsResultOf(operation) {} type ( JSONParserArgs struct { @@ -108,22 +74,10 @@ type ( Value interface{} Err error } - OnJSONParserOperationStart func(dyngo.Operation, JSONParserArgs) - OnJSONParserOperationFinish func(dyngo.Operation, JSONParserRes) ) -func (f OnJSONParserOperationStart) ListenedType() reflect.Type { - return reflect.TypeOf((*JSONParserArgs)(nil)).Elem() -} -func (f OnJSONParserOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op, v.(JSONParserArgs)) -} -func (f OnJSONParserOperationFinish) ListenedType() reflect.Type { - return reflect.TypeOf((*JSONParserRes)(nil)).Elem() -} -func (f OnJSONParserOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op, v.(JSONParserRes)) -} +func (JSONParserArgs) IsArgOf(operation) {} +func (JSONParserRes) IsResultOf(operation) {} type ( BodyReadArgs struct{} @@ -131,98 +85,50 @@ type ( Buf []byte Err error } - OnBodyReadOperationStart func(dyngo.Operation, BodyReadArgs) - OnBodyReadOperationFinish func(dyngo.Operation, BodyReadRes) ) -func (f OnBodyReadOperationStart) ListenedType() reflect.Type { - return reflect.TypeOf((*BodyReadArgs)(nil)).Elem() -} -func (f OnBodyReadOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op, v.(BodyReadArgs)) -} -func (f OnBodyReadOperationFinish) ListenedType() reflect.Type { - return reflect.TypeOf((*BodyReadRes)(nil)).Elem() -} -func (f OnBodyReadOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op, v.(BodyReadRes)) -} +func (BodyReadArgs) IsArgOf(operation) {} +func (BodyReadRes) IsResultOf(operation) {} type ( - MyOperationArgs struct{ n int } - MyOperationRes struct{ n int } - OnMyOperationStart func(dyngo.Operation, MyOperationArgs) - OnMyOperationFinish func(dyngo.Operation, MyOperationRes) + MyOperationArgs struct{ n int } + MyOperationRes struct{ n int } ) -func (f OnMyOperationStart) ListenedType() reflect.Type { - return reflect.TypeOf((*MyOperationArgs)(nil)).Elem() -} -func (f OnMyOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op, v.(MyOperationArgs)) -} -func (f OnMyOperationFinish) ListenedType() reflect.Type { - return reflect.TypeOf((*MyOperationRes)(nil)).Elem() -} -func (f OnMyOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op, v.(MyOperationRes)) -} +func (MyOperationArgs) IsArgOf(operation) {} +func (MyOperationRes) IsResultOf(operation) {} type ( - MyOperation2Args struct{} - MyOperation2Res struct{} - OnMyOperation2Start func(dyngo.Operation, MyOperation2Args) - OnMyOperation2Finish func(dyngo.Operation, MyOperation2Res) + MyOperation2Args struct{} + MyOperation2Res struct{} ) -func (f OnMyOperation2Start) ListenedType() reflect.Type { - return reflect.TypeOf((*MyOperation2Args)(nil)).Elem() -} -func (f OnMyOperation2Start) Call(op dyngo.Operation, v interface{}) { - f(op, v.(MyOperation2Args)) -} -func (f OnMyOperation2Finish) ListenedType() reflect.Type { - return reflect.TypeOf((*MyOperation2Res)(nil)).Elem() -} -func (f OnMyOperation2Finish) Call(op dyngo.Operation, v interface{}) { - f(op, v.(MyOperation2Res)) -} +func (MyOperation2Args) IsArgOf(operation) {} +func (MyOperation2Res) IsResultOf(operation) {} type ( - MyOperation3Args struct{} - MyOperation3Res struct{} - OnMyOperation3Start func(dyngo.Operation, MyOperation3Args) - OnMyOperation3Finish func(dyngo.Operation, MyOperation3Res) + MyOperation3Args struct{} + MyOperation3Res struct{} ) -func (f OnMyOperation3Start) ListenedType() reflect.Type { - return reflect.TypeOf((*MyOperation3Args)(nil)).Elem() -} -func (f OnMyOperation3Start) Call(op dyngo.Operation, v interface{}) { - f(op, v.(MyOperation3Args)) -} -func (f OnMyOperation3Finish) ListenedType() reflect.Type { - return reflect.TypeOf((*MyOperation3Res)(nil)).Elem() -} -func (f OnMyOperation3Finish) Call(op dyngo.Operation, v interface{}) { - f(op, v.(MyOperation3Res)) -} +func (MyOperation3Args) IsArgOf(operation) {} +func (MyOperation3Res) IsResultOf(operation) {} func TestUsage(t *testing.T) { t.Run("operation-stacking", func(t *testing.T) { // HTTP body read listener appending the read results to a buffer - rawBodyListener := func(called *int, buf *[]byte) dyngo.EventListener { - return OnHTTPHandlerOperationStart(func(op dyngo.Operation, _ HTTPHandlerArgs) { - op.On(OnBodyReadOperationFinish(func(op dyngo.Operation, res BodyReadRes) { + rawBodyListener := func(called *int, buf *[]byte) dyngo.EventListener[operation, HTTPHandlerArgs] { + return func(op operation, _ HTTPHandlerArgs) { + dyngo.OnFinish(op, func(op operation, res BodyReadRes) { *called++ *buf = append(*buf, res.Buf...) - })) - }) + }) + } } // Dummy waf looking for the string `attack` in HTTPHandlerArgs - wafListener := func(called *int, blocked *bool) dyngo.EventListener { - return OnHTTPHandlerOperationStart(func(op dyngo.Operation, args HTTPHandlerArgs) { + wafListener := func(called *int, blocked *bool) dyngo.EventListener[operation, HTTPHandlerArgs] { + return func(op operation, args HTTPHandlerArgs) { *called++ if strings.Contains(args.URL.RawQuery, "attack") { @@ -237,27 +143,27 @@ func TestUsage(t *testing.T) { } } } - }) + } } - jsonBodyValueListener := func(called *int, value *interface{}) dyngo.EventListener { - return OnHTTPHandlerOperationStart(func(op dyngo.Operation, _ HTTPHandlerArgs) { - op.On(OnJSONParserOperationStart(func(op dyngo.Operation, v JSONParserArgs) { + jsonBodyValueListener := func(called *int, value *interface{}) dyngo.EventListener[operation, HTTPHandlerArgs] { + return func(op operation, _ HTTPHandlerArgs) { + dyngo.On(op, func(op operation, v JSONParserArgs) { didBodyRead := false - op.On(OnBodyReadOperationStart(func(_ dyngo.Operation, _ BodyReadArgs) { + dyngo.On(op, func(_ operation, _ BodyReadArgs) { didBodyRead = true - })) + }) - op.On(OnJSONParserOperationFinish(func(op dyngo.Operation, res JSONParserRes) { + dyngo.OnFinish(op, func(op operation, res JSONParserRes) { *called++ if !didBodyRead || res.Err != nil { return } *value = res.Value - })) - })) - }) + }) + }) + } } t.Run("operation-stacking", func(t *testing.T) { @@ -282,9 +188,9 @@ func TestUsage(t *testing.T) { ) jsonBodyValueListener := jsonBodyValueListener(&JSONBodyParserCalled, &JSONBodyParserValue) - root.On(rawBodyListener) - root.On(wafListener) - root.On(jsonBodyValueListener) + dyngo.On(root, rawBodyListener) + dyngo.On(root, wafListener) + dyngo.On(root, jsonBodyValueListener) // Run the monitored stack of operations runOperation( @@ -339,9 +245,9 @@ func TestUsage(t *testing.T) { ) jsonBodyValueListener := jsonBodyValueListener(&JSONBodyParserCalled, &JSONBodyParserValue) - root.On(rawBodyListener) - root.On(wafListener) - root.On(jsonBodyValueListener) + dyngo.On(root, rawBodyListener) + dyngo.On(root, wafListener) + dyngo.On(root, jsonBodyValueListener) // Run the monitored stack of operations RawBodyBuf = nil @@ -399,9 +305,9 @@ func TestUsage(t *testing.T) { ) jsonBodyValueListener := jsonBodyValueListener(&JSONBodyParserCalled, &JSONBodyParserValue) - root.On(rawBodyListener) - root.On(wafListener) - root.On(jsonBodyValueListener) + dyngo.On(root, rawBodyListener) + dyngo.On(root, wafListener) + dyngo.On(root, jsonBodyValueListener) // Run the monitored stack of operations runOperation( @@ -434,12 +340,10 @@ func TestUsage(t *testing.T) { t.Run("recursive-operation", func(t *testing.T) { root := startOperation(RootArgs{}, nil) - defer root.Finish(RootRes{}) + defer dyngo.FinishOperation(root, RootRes{}) called := 0 - root.On(OnHTTPHandlerOperationStart(func(dyngo.Operation, HTTPHandlerArgs) { - called++ - })) + dyngo.On(root, func(operation, HTTPHandlerArgs) { called++ }) runOperation(root, HTTPHandlerArgs{}, HTTPHandlerRes{}, func(o dyngo.Operation) { runOperation(o, HTTPHandlerArgs{}, HTTPHandlerRes{}, func(o dyngo.Operation) { @@ -455,47 +359,10 @@ func TestUsage(t *testing.T) { require.Equal(t, 5, called) }) - t.Run("wrapped-operation-type-assertion", func(t *testing.T) { - // dyngo's API should allow to retrieve the actual wrapper types: an - // event listener should be called with the wrapped value. - - // Define `myop` so that it wraps a dyngo.Operation value so that it - // implements dyngo.Operation interface and we can check the event - // listeners get called with a value of type `myop`. - type myop struct { - dyngo.Operation - // count the number of calls to check the test is working as expected - called int - } - - // Create a root operation to listen for a child `myop` operation. - someRoot := dyngo.NewOperation(nil) - dyngo.StartOperation(someRoot, RootArgs{}) - defer dyngo.FinishOperation(someRoot, RootRes{}) - // Register start and finish event listeners, and type-assert that the - // passed operation has type `myop`. - someRoot.On(OnMyOperationStart(func(op dyngo.Operation, _ MyOperationArgs) { - v, ok := op.(*myop) - assert.True(t, ok) - v.called++ - })) - someRoot.On(OnMyOperationFinish(func(op dyngo.Operation, _ MyOperationRes) { - v, ok := op.(*myop) - assert.True(t, ok) - v.called++ - })) - - // Create a `myop` pointer value and start an operation with it. - op := &myop{Operation: dyngo.NewOperation(someRoot)} - dyngo.StartOperation(op, MyOperationArgs{}) - dyngo.FinishOperation(op, MyOperationRes{}) - require.Equal(t, 2, op.called) - }) - t.Run("concurrency", func(t *testing.T) { // root is the shared operation having concurrent accesses in this test root := startOperation(RootArgs{}, nil) - defer root.Finish(RootRes{}) + defer dyngo.FinishOperation(root, RootRes{}) // Create nbGoroutines registering event listeners concurrently nbGoroutines := 1000 @@ -514,8 +381,8 @@ func TestUsage(t *testing.T) { started.Done() startBarrier.Wait() defer done.Done() - root.On(OnMyOperationStart(func(dyngo.Operation, MyOperationArgs) { atomic.AddUint32(&calls, 1) })) - root.On(OnMyOperationFinish(func(dyngo.Operation, MyOperationRes) { atomic.AddUint32(&calls, 1) })) + dyngo.On(root, func(operation, MyOperationArgs) { atomic.AddUint32(&calls, 1) }) + dyngo.OnFinish(root, func(operation, MyOperationRes) { atomic.AddUint32(&calls, 1) }) }() } @@ -558,12 +425,8 @@ func TestSwapRootOperation(t *testing.T) { var onStartCalled, onFinishCalled int root := dyngo.NewRootOperation() - root.On(OnMyOperationStart(func(dyngo.Operation, MyOperationArgs) { - onStartCalled++ - })) - root.On(OnMyOperationFinish(func(dyngo.Operation, MyOperationRes) { - onFinishCalled++ - })) + dyngo.On(root, func(operation, MyOperationArgs) { onStartCalled++ }) + dyngo.OnFinish(root, func(operation, MyOperationRes) { onFinishCalled++ }) dyngo.SwapRootOperation(root) runOperation(nil, MyOperationArgs{}, MyOperationRes{}, func(op dyngo.Operation) {}) @@ -591,17 +454,14 @@ func TestSwapRootOperation(t *testing.T) { type operation struct{ dyngo.Operation } // Helper function to create an operation, wrap it and start it -func startOperation(args interface{}, parent dyngo.Operation) operation { +func startOperation[T dyngo.ArgOf[operation]](args T, parent dyngo.Operation) operation { op := operation{dyngo.NewOperation(parent)} dyngo.StartOperation(op, args) return op } -// Helper method to finish the operation -func (op operation) Finish(res interface{}) { dyngo.FinishOperation(op, res) } - // Helper function to run operations recursively. -func runOperation(parent dyngo.Operation, args, res interface{}, child func(dyngo.Operation)) { +func runOperation[A dyngo.ArgOf[operation], R dyngo.ResultOf[operation]](parent dyngo.Operation, args A, res R, child func(dyngo.Operation)) { op := startOperation(args, parent) defer dyngo.FinishOperation(op, res) if child != nil { @@ -613,42 +473,42 @@ func TestOperationData(t *testing.T) { t.Run("data-transit", func(t *testing.T) { data := 0 op := startOperation(MyOperationArgs{}, nil) - op.OnData(dyngo.NewDataListener(func(data *int) { + dyngo.OnData(op, func(data *int) { *data++ - })) + }) for i := 0; i < 10; i++ { - op.EmitData(&data) + dyngo.EmitData(op, &data) } - op.Finish(MyOperationRes{}) + dyngo.FinishOperation(op, MyOperationRes{}) require.Equal(t, 10, data) }) t.Run("bubble-up", func(t *testing.T) { - listener := dyngo.NewDataListener(func(data *int) { *data++ }) + listener := func(data *int) { *data++ } t.Run("single-listener", func(t *testing.T) { data := 0 op1 := startOperation(MyOperationArgs{}, nil) - op1.OnData(listener) + dyngo.OnData(op1, listener) op2 := startOperation(MyOperation2Args{}, op1) for i := 0; i < 10; i++ { - op2.EmitData(&data) + dyngo.EmitData(op2, &data) } - op2.Finish(MyOperation2Res{}) - op1.Finish(MyOperationRes{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) + dyngo.FinishOperation(op1, MyOperationRes{}) require.Equal(t, 10, data) }) t.Run("double-listener", func(t *testing.T) { data := 0 op1 := startOperation(MyOperationArgs{}, nil) - op1.OnData(listener) + dyngo.OnData(op1, listener) op2 := startOperation(MyOperation2Args{}, op1) - op2.OnData(listener) + dyngo.OnData(op2, listener) for i := 0; i < 10; i++ { - op2.EmitData(&data) + dyngo.EmitData(op2, &data) } - op2.Finish(MyOperation2Res{}) - op1.Finish(MyOperationRes{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) + dyngo.FinishOperation(op1, MyOperationRes{}) require.Equal(t, 20, data) }) }) @@ -659,18 +519,18 @@ func TestOperationEvents(t *testing.T) { op1 := startOperation(MyOperationArgs{}, nil) var called int - op1.On(OnMyOperation2Start(func(dyngo.Operation, MyOperation2Args) { + dyngo.On(op1, func(operation, MyOperation2Args) { called++ - })) + }) op2 := startOperation(MyOperation2Args{}, op1) - op2.Finish(MyOperation2Res{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) // Called once require.Equal(t, 1, called) op2 = startOperation(MyOperation2Args{}, op1) - op2.Finish(MyOperation2Res{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) // Called again require.Equal(t, 2, called) @@ -679,7 +539,7 @@ func TestOperationEvents(t *testing.T) { dyngo.FinishOperation(op1, MyOperationRes{}) op2 = startOperation(MyOperation2Args{}, op1) - op2.Finish(MyOperation2Res{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) // No longer called require.Equal(t, 2, called) @@ -689,40 +549,40 @@ func TestOperationEvents(t *testing.T) { op1 := startOperation(MyOperationArgs{}, nil) var called int - op1.On(OnMyOperation2Finish(func(dyngo.Operation, MyOperation2Res) { + dyngo.OnFinish(op1, func(operation, MyOperation2Res) { called++ - })) + }) op2 := startOperation(MyOperation2Args{}, op1) - op2.Finish(MyOperation2Res{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) // Called once require.Equal(t, 1, called) op2 = startOperation(MyOperation2Args{}, op1) - op2.Finish(MyOperation2Res{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) // Called again require.Equal(t, 2, called) op3 := startOperation(MyOperation3Args{}, op2) - op3.Finish(MyOperation3Res{}) + dyngo.FinishOperation(op3, MyOperation3Res{}) // Not called require.Equal(t, 2, called) op2 = startOperation(MyOperation2Args{}, op3) - op2.Finish(MyOperation2Res{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) // Called again require.Equal(t, 3, called) // Finish the operation so that it gets disabled and its listeners removed - op1.Finish(MyOperationRes{}) + dyngo.FinishOperation(op1, MyOperationRes{}) op2 = startOperation(MyOperation2Args{}, op3) - op2.Finish(MyOperation2Res{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) // No longer called require.Equal(t, 3, called) op2 = startOperation(MyOperation2Args{}, op2) - op2.Finish(MyOperation2Res{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) // No longer called require.Equal(t, 3, called) }) @@ -730,12 +590,12 @@ func TestOperationEvents(t *testing.T) { t.Run("disabled-operation-registration", func(t *testing.T) { var calls int registerTo := func(op dyngo.Operation) { - op.On(OnMyOperation2Start(func(dyngo.Operation, MyOperation2Args) { + dyngo.On(op, func(operation, MyOperation2Args) { calls++ - })) - op.On(OnMyOperation2Finish(func(dyngo.Operation, MyOperation2Res) { + }) + dyngo.OnFinish(op, func(operation, MyOperation2Res) { calls++ - })) + }) } // Start an operation and register event listeners to it. @@ -745,16 +605,16 @@ func TestOperationEvents(t *testing.T) { // Trigger the registered events op2 := startOperation(MyOperation2Args{}, op) - op2.Finish(MyOperation2Res{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) // We should have 4 calls require.Equal(t, 2, calls) // Finish the operation to disable it. Its event listeners should then be removed. - op.Finish(MyOperationRes{}) + dyngo.FinishOperation(op, MyOperationRes{}) // Trigger the same events op2 = startOperation(MyOperation2Args{}, op) - op2.Finish(MyOperation2Res{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) // The number of calls should be unchanged require.Equal(t, 2, calls) @@ -762,7 +622,7 @@ func TestOperationEvents(t *testing.T) { registerTo(op) // Trigger the same events op2 = startOperation(MyOperation2Args{}, op) - op2.Finish(MyOperation2Res{}) + dyngo.FinishOperation(op2, MyOperation2Res{}) // The number of calls should be unchanged require.Equal(t, 2, calls) }) @@ -770,39 +630,39 @@ func TestOperationEvents(t *testing.T) { t.Run("event-listener-panic", func(t *testing.T) { t.Run("start", func(t *testing.T) { op := startOperation(MyOperationArgs{}, nil) - defer op.Finish(MyOperationRes{}) + defer dyngo.FinishOperation(op, MyOperationRes{}) // Panic on start calls := 0 - op.On(OnMyOperationStart(func(dyngo.Operation, MyOperationArgs) { + dyngo.On(op, func(operation, MyOperationArgs) { // Call counter to check we actually call this listener calls++ panic(errors.New("oops")) - })) + }) // Start the operation triggering the event: it should not panic require.NotPanics(t, func() { op := startOperation(MyOperationArgs{}, op) require.NotNil(t, op) - defer op.Finish(MyOperationRes{}) + defer dyngo.FinishOperation(op, MyOperationRes{}) require.Equal(t, calls, 1) }) }) t.Run("finish", func(t *testing.T) { op := startOperation(MyOperationArgs{}, nil) - defer op.Finish(MyOperationRes{}) + defer dyngo.FinishOperation(op, MyOperationRes{}) // Panic on finish calls := 0 - op.On(OnMyOperationFinish(func(dyngo.Operation, MyOperationRes) { + dyngo.OnFinish(op, func(operation, MyOperationRes) { // Call counter to check we actually call this listener calls++ panic(errors.New("oops")) - })) + }) // Run the operation triggering the finish event: it should not panic require.NotPanics(t, func() { op := startOperation(MyOperationArgs{}, op) require.NotNil(t, op) - op.Finish(MyOperationRes{}) + dyngo.FinishOperation(op, MyOperationRes{}) require.Equal(t, calls, 1) }) }) @@ -815,16 +675,16 @@ func BenchmarkEvents(b *testing.B) { for length := 1; length <= 64; length *= 2 { b.Run(fmt.Sprintf("stack=%d", length), func(b *testing.B) { root := startOperation(MyOperationArgs{}, nil) - defer root.Finish(MyOperationRes{}) + defer dyngo.FinishOperation(root, MyOperationRes{}) op := root for i := 0; i < length-1; i++ { op = startOperation(MyOperationArgs{}, op) - defer op.Finish(MyOperationRes{}) + defer dyngo.FinishOperation(op, MyOperationRes{}) } b.Run("start event", func(b *testing.B) { - root.On(OnMyOperationStart(func(dyngo.Operation, MyOperationArgs) {})) + dyngo.On(root, func(operation, MyOperationArgs) {}) b.ReportAllocs() b.ResetTimer() @@ -834,13 +694,13 @@ func BenchmarkEvents(b *testing.B) { }) b.Run("start + finish events", func(b *testing.B) { - root.On(OnMyOperationFinish(func(dyngo.Operation, MyOperationRes) {})) + dyngo.OnFinish(root, func(operation, MyOperationRes) {}) b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { leafOp := startOperation(MyOperationArgs{}, op) - leafOp.Finish(MyOperationRes{}) + dyngo.FinishOperation(leafOp, MyOperationRes{}) } }) }) @@ -849,19 +709,19 @@ func BenchmarkEvents(b *testing.B) { b.Run("registering", func(b *testing.B) { op := startOperation(MyOperationArgs{}, nil) - defer op.Finish(MyOperationRes{}) + defer dyngo.FinishOperation(op, MyOperationRes{}) b.Run("start event", func(b *testing.B) { b.ReportAllocs() for n := 0; n < b.N; n++ { - op.On(OnMyOperationStart(func(dyngo.Operation, MyOperationArgs) {})) + dyngo.On(op, func(operation, MyOperationArgs) {}) } }) b.Run("finish event", func(b *testing.B) { b.ReportAllocs() for n := 0; n < b.N; n++ { - op.On(OnMyOperationFinish(func(dyngo.Operation, MyOperationRes) {})) + dyngo.OnFinish(op, func(operation, MyOperationRes) {}) } }) }) diff --git a/internal/appsec/emitter/graphqlsec/execution.go b/internal/appsec/emitter/graphqlsec/execution.go index 80c2e893e4..a9263e9913 100644 --- a/internal/appsec/emitter/graphqlsec/execution.go +++ b/internal/appsec/emitter/graphqlsec/execution.go @@ -11,80 +11,28 @@ package graphqlsec import ( "context" - "reflect" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/graphqlsec/types" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/trace" ) -type ExecutionOperation struct { - dyngo.Operation - trace.TagSetter - trace.SecurityEventsHolder -} - -// ExecutionOperationArgs describes arguments passed to a GraphQL query operation. -type ExecutionOperationArgs struct { - // Variables is the user-provided variables object for the query. - Variables map[string]any - // Query is the query that is being executed. - Query string - // OperationName is the user-provided operation name for the query. - OperationName string -} - // StartExecutionOperation starts a new GraphQL query operation, along with the given arguments, and // emits a start event up in the operation stack. The operation is tracked on the returned context, // and can be extracted later on using FromContext. -func StartExecutionOperation(ctx context.Context, parent *RequestOperation, span trace.TagSetter, args ExecutionOperationArgs, listeners ...dyngo.DataListener) (context.Context, *ExecutionOperation) { +func StartExecutionOperation(ctx context.Context, parent *types.RequestOperation, span trace.TagSetter, args types.ExecutionOperationArgs) (context.Context, *types.ExecutionOperation) { if span == nil { // The span may be nil (e.g: in case of GraphQL subscriptions with certian contribs). Child // operations might have spans however... and these should be used then. span = trace.NoopTagSetter{} } - op := &ExecutionOperation{ + op := &types.ExecutionOperation{ Operation: dyngo.NewOperation(parent), TagSetter: span, } - for _, l := range listeners { - op.OnData(l) - } newCtx := contextWithValue(ctx, op) dyngo.StartOperation(op, args) return newCtx, op } - -// Finish the GraphQL query operation, along with the given results, and emit a finish event up in -// the operation stack. -func (q *ExecutionOperation) Finish(res ExecutionOperationRes) { - dyngo.FinishOperation(q, res) -} - -type ( - OnExecutionOperationStart func(*ExecutionOperation, ExecutionOperationArgs) - OnExecutionOperationFinish func(*ExecutionOperation, ExecutionOperationRes) - - ExecutionOperationRes struct { - // Data is the data returned from processing the GraphQL operation. - Data any - // Error is the error returned by processing the GraphQL Operation, if any. - Error error - } -) - -var ( - executionOperationStartArgsType = reflect.TypeOf((*ExecutionOperationArgs)(nil)).Elem() - executionOperationFinishResType = reflect.TypeOf((*ExecutionOperationRes)(nil)).Elem() -) - -func (OnExecutionOperationStart) ListenedType() reflect.Type { return executionOperationStartArgsType } -func (f OnExecutionOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op.(*ExecutionOperation), v.(ExecutionOperationArgs)) -} - -func (OnExecutionOperationFinish) ListenedType() reflect.Type { return executionOperationFinishResType } -func (f OnExecutionOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op.(*ExecutionOperation), v.(ExecutionOperationRes)) -} diff --git a/internal/appsec/emitter/graphqlsec/init.go b/internal/appsec/emitter/graphqlsec/init.go new file mode 100644 index 0000000000..a38d7932d7 --- /dev/null +++ b/internal/appsec/emitter/graphqlsec/init.go @@ -0,0 +1,15 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package graphqlsec + +import ( + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/graphqlsec" +) + +func init() { + appsec.AddWAFEventListener(graphqlsec.Install) +} diff --git a/internal/appsec/emitter/graphqlsec/request.go b/internal/appsec/emitter/graphqlsec/request.go index e38e2f2dd7..51e137cf09 100644 --- a/internal/appsec/emitter/graphqlsec/request.go +++ b/internal/appsec/emitter/graphqlsec/request.go @@ -11,36 +11,23 @@ package graphqlsec import ( "context" - "reflect" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/graphqlsec/types" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/trace" ) -type RequestOperation struct { - dyngo.Operation - trace.TagSetter - trace.SecurityEventsHolder -} - -// RequestOperationArgs describes arguments passed to a GraphQL request. -type RequestOperationArgs struct { - RawQuery string // The raw, not-yet-parsed GraphQL query - OperationName string // The user-provided operation name for the query - Variables map[string]any // The user-provided variables object for this request -} - // StartRequestOperation starts a new GraphQL request operation, along with the given arguments, and // emits a start event up in the operation stack. The operation is usually linked to tge global root // operation. The operation is tracked on the returned context, and can be extracted later on using // FromContext. -func StartRequestOperation(ctx context.Context, parent dyngo.Operation, span trace.TagSetter, args RequestOperationArgs) (context.Context, *RequestOperation) { +func StartRequestOperation(ctx context.Context, parent dyngo.Operation, span trace.TagSetter, args types.RequestOperationArgs) (context.Context, *types.RequestOperation) { if span == nil { // The span may be nil (e.g: in case of GraphQL subscriptions with certian contribs) span = trace.NoopTagSetter{} } - op := &RequestOperation{ + op := &types.RequestOperation{ Operation: dyngo.NewOperation(parent), TagSetter: span, } @@ -49,36 +36,3 @@ func StartRequestOperation(ctx context.Context, parent dyngo.Operation, span tra return newCtx, op } - -// Finish the GraphQL query operation, along with the given results, and emit a finish event up in -// the operation stack. -func (q *RequestOperation) Finish(res RequestOperationRes) { - dyngo.FinishOperation(q, res) -} - -type ( - OnRequestOperationStart func(*RequestOperation, RequestOperationArgs) - OnRequestOperationFinish func(*RequestOperation, RequestOperationRes) - - RequestOperationRes struct { - // Data is the data returned from processing the GraphQL operation. - Data any - // Error is the error returned by processing the GraphQL Operation, if any. - Error error - } -) - -var ( - requestOperationStartArgsType = reflect.TypeOf((*RequestOperationArgs)(nil)).Elem() - requestOperationFinishResType = reflect.TypeOf((*RequestOperationRes)(nil)).Elem() -) - -func (OnRequestOperationStart) ListenedType() reflect.Type { return requestOperationStartArgsType } -func (f OnRequestOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op.(*RequestOperation), v.(RequestOperationArgs)) -} - -func (OnRequestOperationFinish) ListenedType() reflect.Type { return requestOperationFinishResType } -func (f OnRequestOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op.(*RequestOperation), v.(RequestOperationRes)) -} diff --git a/internal/appsec/emitter/graphqlsec/resolve.go b/internal/appsec/emitter/graphqlsec/resolve.go index 9f55c25ae5..967a58f72d 100644 --- a/internal/appsec/emitter/graphqlsec/resolve.go +++ b/internal/appsec/emitter/graphqlsec/resolve.go @@ -7,35 +7,17 @@ package graphqlsec import ( "context" - "reflect" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/graphqlsec/types" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/trace" ) -type ResolveOperation struct { - dyngo.Operation - trace.TagSetter - trace.SecurityEventsHolder -} - -// ResolveOperationArgs describes arguments passed to a GraphQL field operation. -type ResolveOperationArgs struct { - // TypeName is the name of the field's type - TypeName string - // FieldName is the name of the field - FieldName string - // Arguments is the arguments provided to the field resolver - Arguments map[string]any - // Trivial determines whether the resolution is trivial or not. Leave as false if undetermined. - Trivial bool -} - // StartResolveOperation starts a new GraphQL Resolve operation, along with the given arguments, and // emits a start event up in the operation stack. The operation is tracked on the returned context, // and can be extracted later on using FromContext. -func StartResolveOperation(ctx context.Context, parent *ExecutionOperation, span trace.TagSetter, args ResolveOperationArgs) (context.Context, *ResolveOperation) { - op := &ResolveOperation{ +func StartResolveOperation(ctx context.Context, parent *types.ExecutionOperation, span trace.TagSetter, args types.ResolveOperationArgs) (context.Context, *types.ResolveOperation) { + op := &types.ResolveOperation{ Operation: dyngo.NewOperation(parent), TagSetter: span, } @@ -44,36 +26,3 @@ func StartResolveOperation(ctx context.Context, parent *ExecutionOperation, span return newCtx, op } - -// Finish the GraphQL Field operation, along with the given results, and emit a finish event up in -// the operation stack. -func (q *ResolveOperation) Finish(res ResolveOperationRes) { - dyngo.FinishOperation(q, res) -} - -type ( - OnResolveOperationStart func(*ResolveOperation, ResolveOperationArgs) - OnResolveOperationFinish func(*ResolveOperation, ResolveOperationRes) - - ResolveOperationRes struct { - // Data is the data returned from processing the GraphQL operation. - Data any - // Error is the error returned by processing the GraphQL Operation, if any. - Error error - } -) - -var ( - resolveOperationStartArgsType = reflect.TypeOf((*ResolveOperationArgs)(nil)).Elem() - resolveOperationFinishResType = reflect.TypeOf((*ResolveOperationRes)(nil)).Elem() -) - -func (OnResolveOperationStart) ListenedType() reflect.Type { return resolveOperationStartArgsType } -func (f OnResolveOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op.(*ResolveOperation), v.(ResolveOperationArgs)) -} - -func (OnResolveOperationFinish) ListenedType() reflect.Type { return resolveOperationFinishResType } -func (f OnResolveOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op.(*ResolveOperation), v.(ResolveOperationRes)) -} diff --git a/internal/appsec/emitter/graphqlsec/types/types.go b/internal/appsec/emitter/graphqlsec/types/types.go new file mode 100644 index 0000000000..d8b0d1948c --- /dev/null +++ b/internal/appsec/emitter/graphqlsec/types/types.go @@ -0,0 +1,112 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package types + +import ( + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/trace" +) + +type ( + RequestOperation struct { + dyngo.Operation + trace.TagSetter + trace.SecurityEventsHolder + } + + // RequestOperationArgs describes arguments passed to a GraphQL request. + RequestOperationArgs struct { + RawQuery string // The raw, not-yet-parsed GraphQL query + OperationName string // The user-provided operation name for the query + Variables map[string]any // The user-provided variables object for this request + } + + RequestOperationRes struct { + // Data is the data returned from processing the GraphQL operation. + Data any + // Error is the error returned by processing the GraphQL Operation, if any. + Error error + } +) + +// Finish the GraphQL query operation, along with the given results, and emit a finish event up in +// the operation stack. +func (q *RequestOperation) Finish(res RequestOperationRes) { + dyngo.FinishOperation(q, res) +} + +func (RequestOperationArgs) IsArgOf(*RequestOperation) {} +func (RequestOperationRes) IsResultOf(*RequestOperation) {} + +type ( + ExecutionOperation struct { + dyngo.Operation + trace.TagSetter + trace.SecurityEventsHolder + } + + // ExecutionOperationArgs describes arguments passed to a GraphQL query operation. + ExecutionOperationArgs struct { + // Variables is the user-provided variables object for the query. + Variables map[string]any + // Query is the query that is being executed. + Query string + // OperationName is the user-provided operation name for the query. + OperationName string + } + + ExecutionOperationRes struct { + // Data is the data returned from processing the GraphQL operation. + Data any + // Error is the error returned by processing the GraphQL Operation, if any. + Error error + } +) + +// Finish the GraphQL query operation, along with the given results, and emit a finish event up in +// the operation stack. +func (q *ExecutionOperation) Finish(res ExecutionOperationRes) { + dyngo.FinishOperation(q, res) +} + +func (ExecutionOperationArgs) IsArgOf(*ExecutionOperation) {} +func (ExecutionOperationRes) IsResultOf(*ExecutionOperation) {} + +type ( + ResolveOperation struct { + dyngo.Operation + trace.TagSetter + trace.SecurityEventsHolder + } + + // ResolveOperationArgs describes arguments passed to a GraphQL field operation. + ResolveOperationArgs struct { + // TypeName is the name of the field's type + TypeName string + // FieldName is the name of the field + FieldName string + // Arguments is the arguments provided to the field resolver + Arguments map[string]any + // Trivial determines whether the resolution is trivial or not. Leave as false if undetermined. + Trivial bool + } + + ResolveOperationRes struct { + // Data is the data returned from processing the GraphQL operation. + Data any + // Error is the error returned by processing the GraphQL Operation, if any. + Error error + } +) + +// Finish the GraphQL Field operation, along with the given results, and emit a finish event up in +// the operation stack. +func (q *ResolveOperation) Finish(res ResolveOperationRes) { + dyngo.FinishOperation(q, res) +} + +func (ResolveOperationArgs) IsArgOf(*ResolveOperation) {} +func (ResolveOperationRes) IsResultOf(*ResolveOperation) {} diff --git a/internal/appsec/emitter/grpcsec/grpc.go b/internal/appsec/emitter/grpcsec/grpc.go index 1b831f2e23..70fa5eb7af 100644 --- a/internal/appsec/emitter/grpcsec/grpc.go +++ b/internal/appsec/emitter/grpcsec/grpc.go @@ -11,199 +11,36 @@ package grpcsec import ( "context" - "reflect" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/grpcsec/types" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/trace" - - "github.com/DataDog/appsec-internal-go/netip" ) -// Abstract gRPC server handler operation definitions. It is based on two -// operations allowing to describe every type of RPC: the HandlerOperation type -// which represents the RPC handler, and the ReceiveOperation type which -// represents the messages the RPC handler receives during its lifetime. -// This means that the ReceiveOperation(s) will happen within the -// HandlerOperation. -// Every type of RPC, unary, client streaming, server streaming, and -// bidirectional streaming RPCs, can be all represented with a HandlerOperation -// having one or several ReceiveOperation. -// The send operation is not required for now and therefore not defined, which -// means that server and bidirectional streaming RPCs currently have the same -// run-time representation as unary and client streaming RPCs. -type ( - // HandlerOperation represents a gRPC server handler operation. - // It must be created with StartHandlerOperation() and finished with its - // Finish() method. - // Security events observed during the operation lifetime should be added - // to the operation using its AddSecurityEvent() method. - HandlerOperation struct { - dyngo.Operation - trace.TagsHolder - trace.SecurityEventsHolder - Error error - } - // HandlerOperationArgs is the grpc handler arguments. - HandlerOperationArgs struct { - // Message received by the gRPC handler. - // Corresponds to the address `grpc.server.request.metadata`. - Metadata map[string][]string - ClientIP netip.Addr - } - // HandlerOperationRes is the grpc handler results. Empty as of today. - HandlerOperationRes struct{} - - // ReceiveOperation type representing an gRPC server handler operation. It must - // be created with StartReceiveOperation() and finished with its Finish(). - ReceiveOperation struct { - dyngo.Operation - } - // ReceiveOperationArgs is the gRPC handler receive operation arguments - // Empty as of today. - ReceiveOperationArgs struct{} - // ReceiveOperationRes is the gRPC handler receive operation results which - // contains the message the gRPC handler received. - ReceiveOperationRes struct { - // Message received by the gRPC handler. - // Corresponds to the address `grpc.server.request.message`. - Message interface{} - } - - // MonitoringError is used to vehicle a gRPC error that also embeds a request status code - MonitoringError struct { - msg string - status uint32 - } -) - -// NewMonitoringError creates and returns a new gRPC monitoring error, wrapped under -// sharedesec.MonitoringError -func NewMonitoringError(msg string, code uint32) error { - return &MonitoringError{ - msg: msg, - status: code, - } -} - -// GRPCStatus returns the gRPC status code embedded in the error -func (e *MonitoringError) GRPCStatus() uint32 { - return e.status -} - -// Error implements the error interface -func (e *MonitoringError) Error() string { - return e.msg -} - -// TODO(Julio-Guerra): create a go-generate tool to generate the types, vars and methods below - // StartHandlerOperation starts an gRPC server handler operation, along with the // given arguments and parent operation, and emits a start event up in the // operation stack. When parent is nil, the operation is linked to the global // root operation. -func StartHandlerOperation(ctx context.Context, args HandlerOperationArgs, parent dyngo.Operation, listeners ...dyngo.DataListener) (context.Context, *HandlerOperation) { - op := &HandlerOperation{ +func StartHandlerOperation(ctx context.Context, args types.HandlerOperationArgs, parent dyngo.Operation, setup ...func(*types.HandlerOperation)) (context.Context, *types.HandlerOperation) { + op := &types.HandlerOperation{ Operation: dyngo.NewOperation(parent), TagsHolder: trace.NewTagsHolder(), } - for _, l := range listeners { - op.OnData(l) - } newCtx := context.WithValue(ctx, listener.ContextKey{}, op) + for _, cb := range setup { + cb(op) + } dyngo.StartOperation(op, args) return newCtx, op } -// Finish the gRPC handler operation, along with the given results, and emit a -// finish event up in the operation stack. -func (op *HandlerOperation) Finish(res HandlerOperationRes) []any { - dyngo.FinishOperation(op, res) - return op.Events() -} - -// gRPC handler operation's start and finish event callback function types. -type ( - // OnHandlerOperationStart function type, called when an gRPC handler - // operation starts. - OnHandlerOperationStart func(*HandlerOperation, HandlerOperationArgs) - // OnHandlerOperationFinish function type, called when an gRPC handler - // operation finishes. - OnHandlerOperationFinish func(*HandlerOperation, HandlerOperationRes) -) - -var ( - handlerOperationArgsType = reflect.TypeOf((*HandlerOperationArgs)(nil)).Elem() - handlerOperationResType = reflect.TypeOf((*HandlerOperationRes)(nil)).Elem() -) - -// ListenedType returns the type a OnHandlerOperationStart event listener -// listens to, which is the HandlerOperationArgs type. -func (OnHandlerOperationStart) ListenedType() reflect.Type { return handlerOperationArgsType } - -// Call the underlying event listener function by performing the type-assertion -// on v whose type is the one returned by ListenedType(). -func (f OnHandlerOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op.(*HandlerOperation), v.(HandlerOperationArgs)) -} - -// ListenedType returns the type a OnHandlerOperationFinish event listener -// listens to, which is the HandlerOperationRes type. -func (OnHandlerOperationFinish) ListenedType() reflect.Type { return handlerOperationResType } - -// Call the underlying event listener function by performing the type-assertion -// on v whose type is the one returned by ListenedType(). -func (f OnHandlerOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op.(*HandlerOperation), v.(HandlerOperationRes)) -} - // StartReceiveOperation starts a receive operation of a gRPC handler, along // with the given arguments and parent operation, and emits a start event up in // the operation stack. When parent is nil, the operation is linked to the // global root operation. -func StartReceiveOperation(args ReceiveOperationArgs, parent dyngo.Operation) ReceiveOperation { - op := ReceiveOperation{Operation: dyngo.NewOperation(parent)} +func StartReceiveOperation(args types.ReceiveOperationArgs, parent dyngo.Operation) types.ReceiveOperation { + op := types.ReceiveOperation{Operation: dyngo.NewOperation(parent)} dyngo.StartOperation(op, args) return op } - -// Finish the gRPC handler operation, along with the given results, and emits a -// finish event up in the operation stack. -func (op ReceiveOperation) Finish(res ReceiveOperationRes) { - dyngo.FinishOperation(op, res) -} - -// gRPC receive operation's start and finish event callback function types. -type ( - // OnReceiveOperationStart function type, called when a gRPC receive - // operation starts. - OnReceiveOperationStart func(ReceiveOperation, ReceiveOperationArgs) - // OnReceiveOperationFinish function type, called when a grpc receive - // operation finishes. - OnReceiveOperationFinish func(ReceiveOperation, ReceiveOperationRes) -) - -var ( - receiveOperationArgsType = reflect.TypeOf((*ReceiveOperationArgs)(nil)).Elem() - receiveOperationResType = reflect.TypeOf((*ReceiveOperationRes)(nil)).Elem() -) - -// ListenedType returns the type a OnHandlerOperationStart event listener -// listens to, which is the HandlerOperationArgs type. -func (OnReceiveOperationStart) ListenedType() reflect.Type { return receiveOperationArgsType } - -// Call the underlying event listener function by performing the type-assertion -// on v whose type is the one returned by ListenedType(). -func (f OnReceiveOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op.(ReceiveOperation), v.(ReceiveOperationArgs)) -} - -// ListenedType returns the type a OnHandlerOperationFinish event listener -// listens to, which is the HandlerOperationRes type. -func (OnReceiveOperationFinish) ListenedType() reflect.Type { return receiveOperationResType } - -// Call the underlying event listener function by performing the type-assertion -// on v whose type is the one returned by ListenedType(). -func (f OnReceiveOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op.(ReceiveOperation), v.(ReceiveOperationRes)) -} diff --git a/internal/appsec/emitter/grpcsec/grpc_test.go b/internal/appsec/emitter/grpcsec/grpc_test.go index 3ac3651563..c5d8d0916d 100644 --- a/internal/appsec/emitter/grpcsec/grpc_test.go +++ b/internal/appsec/emitter/grpcsec/grpc_test.go @@ -11,19 +11,24 @@ import ( "testing" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" - "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/grpcsec" + grpcsec "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/grpcsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/grpcsec/types" "github.com/stretchr/testify/require" ) +type ( + rootArgs struct{} + rootRes struct{} +) + +func (rootArgs) IsArgOf(dyngo.Operation) {} +func (rootRes) IsResultOf(dyngo.Operation) {} + func TestUsage(t *testing.T) { testRPCRepresentation := func(expectedRecvOperation int) func(*testing.T) { return func(t *testing.T) { - type ( - rootArgs struct{} - rootRes struct{} - ) - localRootOp := dyngo.NewOperation(nil) + localRootOp := dyngo.NewRootOperation() dyngo.StartOperation(localRootOp, rootArgs{}) defer dyngo.FinishOperation(localRootOp, rootRes{}) @@ -37,34 +42,32 @@ func TestUsage(t *testing.T) { const expectedMessageFormat = "message number %d" - localRootOp.On(grpcsec.OnHandlerOperationStart(func(handlerOp *grpcsec.HandlerOperation, args grpcsec.HandlerOperationArgs) { + dyngo.On(localRootOp, func(handlerOp *types.HandlerOperation, args types.HandlerOperationArgs) { handlerStarted++ - handlerOp.On(grpcsec.OnReceiveOperationStart(func(op grpcsec.ReceiveOperation, _ grpcsec.ReceiveOperationArgs) { + dyngo.On(handlerOp, func(op types.ReceiveOperation, _ types.ReceiveOperationArgs) { recvStarted++ - op.On(grpcsec.OnReceiveOperationFinish(func(_ grpcsec.ReceiveOperation, res grpcsec.ReceiveOperationRes) { + dyngo.OnFinish(op, func(_ types.ReceiveOperation, res types.ReceiveOperationRes) { expectedMessage := fmt.Sprintf(expectedMessageFormat, recvStarted) require.Equal(t, expectedMessage, res.Message) recvFinished++ handlerOp.AddSecurityEvents([]any{expectedMessage}) - })) - })) + }) + }) - handlerOp.On(grpcsec.OnHandlerOperationFinish(func(*grpcsec.HandlerOperation, grpcsec.HandlerOperationRes) { - handlerFinished++ - })) - })) + dyngo.OnFinish(handlerOp, func(*types.HandlerOperation, types.HandlerOperationRes) { handlerFinished++ }) + }) - _, rpcOp := grpcsec.StartHandlerOperation(context.Background(), grpcsec.HandlerOperationArgs{}, localRootOp) + _, rpcOp := grpcsec.StartHandlerOperation(context.Background(), types.HandlerOperationArgs{}, localRootOp) for i := 1; i <= expectedRecvOperation; i++ { - recvOp := grpcsec.StartReceiveOperation(grpcsec.ReceiveOperationArgs{}, rpcOp) - recvOp.Finish(grpcsec.ReceiveOperationRes{Message: fmt.Sprintf(expectedMessageFormat, i)}) + recvOp := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, rpcOp) + recvOp.Finish(types.ReceiveOperationRes{Message: fmt.Sprintf(expectedMessageFormat, i)}) } - secEvents := rpcOp.Finish(grpcsec.HandlerOperationRes{}) + secEvents := rpcOp.Finish(types.HandlerOperationRes{}) require.Len(t, secEvents, expectedRecvOperation) for i, e := range secEvents { diff --git a/internal/appsec/emitter/grpcsec/init.go b/internal/appsec/emitter/grpcsec/init.go new file mode 100644 index 0000000000..f79eda99f8 --- /dev/null +++ b/internal/appsec/emitter/grpcsec/init.go @@ -0,0 +1,15 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package grpcsec + +import ( + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/grpcsec" +) + +func init() { + appsec.AddWAFEventListener(grpcsec.Install) +} diff --git a/internal/appsec/emitter/grpcsec/types/types.go b/internal/appsec/emitter/grpcsec/types/types.go new file mode 100644 index 0000000000..7d3f47d82c --- /dev/null +++ b/internal/appsec/emitter/grpcsec/types/types.go @@ -0,0 +1,108 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package types + +import ( + "net/netip" + + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/trace" +) + +// Abstract gRPC server handler operation definitions. It is based on two +// operations allowing to describe every type of RPC: the HandlerOperation type +// which represents the RPC handler, and the ReceiveOperation type which +// represents the messages the RPC handler receives during its lifetime. +// This means that the ReceiveOperation(s) will happen within the +// HandlerOperation. +// Every type of RPC, unary, client streaming, server streaming, and +// bidirectional streaming RPCs, can be all represented with a HandlerOperation +// having one or several ReceiveOperation. +// The send operation is not required for now and therefore not defined, which +// means that server and bidirectional streaming RPCs currently have the same +// run-time representation as unary and client streaming RPCs. +type ( + // HandlerOperation represents a gRPC server handler operation. + // It must be created with StartHandlerOperation() and finished with its + // Finish() method. + // Security events observed during the operation lifetime should be added + // to the operation using its AddSecurityEvent() method. + HandlerOperation struct { + dyngo.Operation + Error error + trace.TagsHolder + trace.SecurityEventsHolder + } + // HandlerOperationArgs is the grpc handler arguments. + HandlerOperationArgs struct { + // Message received by the gRPC handler. + // Corresponds to the address `grpc.server.request.metadata`. + Metadata map[string][]string + ClientIP netip.Addr + } + // HandlerOperationRes is the grpc handler results. Empty as of today. + HandlerOperationRes struct{} + + // ReceiveOperation type representing an gRPC server handler operation. It must + // be created with StartReceiveOperation() and finished with its Finish(). + ReceiveOperation struct { + dyngo.Operation + } + // ReceiveOperationArgs is the gRPC handler receive operation arguments + // Empty as of today. + ReceiveOperationArgs struct{} + // ReceiveOperationRes is the gRPC handler receive operation results which + // contains the message the gRPC handler received. + ReceiveOperationRes struct { + // Message received by the gRPC handler. + // Corresponds to the address `grpc.server.request.message`. + Message interface{} + } + + // MonitoringError is used to vehicle a gRPC error that also embeds a request status code + MonitoringError struct { + msg string + status uint32 + } +) + +// NewMonitoringError creates and returns a new gRPC monitoring error, wrapped under +// sharedesec.MonitoringError +func NewMonitoringError(msg string, code uint32) error { + return &MonitoringError{ + msg: msg, + status: code, + } +} + +// GRPCStatus returns the gRPC status code embedded in the error +func (e *MonitoringError) GRPCStatus() uint32 { + return e.status +} + +// Error implements the error interface +func (e *MonitoringError) Error() string { + return e.msg +} + +// Finish the gRPC handler operation, along with the given results, and emit a +// finish event up in the operation stack. +func (op *HandlerOperation) Finish(res HandlerOperationRes) []any { + dyngo.FinishOperation(op, res) + return op.Events() +} + +// Finish the gRPC handler operation, along with the given results, and emits a +// finish event up in the operation stack. +func (op ReceiveOperation) Finish(res ReceiveOperationRes) { + dyngo.FinishOperation(op, res) +} + +func (HandlerOperationArgs) IsArgOf(*HandlerOperation) {} +func (HandlerOperationRes) IsResultOf(*HandlerOperation) {} + +func (ReceiveOperationArgs) IsArgOf(ReceiveOperation) {} +func (ReceiveOperationRes) IsResultOf(ReceiveOperation) {} diff --git a/internal/appsec/emitter/httpsec/http.go b/internal/appsec/emitter/httpsec/http.go index 35c4f681fa..d721d67ec0 100644 --- a/internal/appsec/emitter/httpsec/http.go +++ b/internal/appsec/emitter/httpsec/http.go @@ -15,12 +15,11 @@ import ( // Blank import needed to use embed for the default blocked response payloads _ "embed" "net/http" - "reflect" "strings" - "sync" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/httpsec/types" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/sharedsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/trace" @@ -30,84 +29,29 @@ import ( "github.com/DataDog/appsec-internal-go/netip" ) -// Abstract HTTP handler operation definition. -type ( - // HandlerOperationArgs is the HTTP handler operation arguments. - HandlerOperationArgs struct { - // Method is the http method verb of the request, address is `server.request.method` - Method string - // RequestURI corresponds to the address `server.request.uri.raw` - RequestURI string - // Headers corresponds to the address `server.request.headers.no_cookies` - Headers map[string][]string - // Cookies corresponds to the address `server.request.cookies` - Cookies map[string][]string - // Query corresponds to the address `server.request.query` - Query map[string][]string - // PathParams corresponds to the address `server.request.path_params` - PathParams map[string]string - // ClientIP corresponds to the address `http.client_ip` - ClientIP netip.Addr - } - - // HandlerOperationRes is the HTTP handler operation results. - HandlerOperationRes struct { - // Status corresponds to the address `server.response.status`. - Status int - Headers map[string][]string - } - - // SDKBodyOperationArgs is the SDK body operation arguments. - SDKBodyOperationArgs struct { - // Body corresponds to the address `server.request.body`. - Body interface{} - } - - // SDKBodyOperationRes is the SDK body operation results. - SDKBodyOperationRes struct{} - - // MonitoringError is used to vehicle an HTTP error, usually resurfaced through Appsec SDKs. - MonitoringError struct { - msg string - } -) - -// Error implements the Error interface -func (e *MonitoringError) Error() string { - return e.msg -} - -// NewMonitoringError creates and returns a new HTTP monitoring error, wrapped under -// sharedesec.MonitoringError -func NewMonitoringError(msg string) error { - return &MonitoringError{ - msg: msg, - } -} - // MonitorParsedBody starts and finishes the SDK body operation. // This function should not be called when AppSec is disabled in order to // get preciser error logs. -func MonitorParsedBody(ctx context.Context, body interface{}) error { +func MonitorParsedBody(ctx context.Context, body any) error { parent := fromContext(ctx) if parent == nil { log.Error("appsec: parsed http body monitoring ignored: could not find the http handler instrumentation metadata in the request context: the request handler is not being monitored by a middleware function or the provided context is not the expected request context") return nil } - return ExecuteSDKBodyOperation(parent, SDKBodyOperationArgs{Body: body}) + return ExecuteSDKBodyOperation(parent, types.SDKBodyOperationArgs{Body: body}) } // ExecuteSDKBodyOperation starts and finishes the SDK Body operation by emitting a dyngo start and finish events // An error is returned if the body associated to that operation must be blocked -func ExecuteSDKBodyOperation(parent dyngo.Operation, args SDKBodyOperationArgs) error { +func ExecuteSDKBodyOperation(parent dyngo.Operation, args types.SDKBodyOperationArgs) error { var err error - op := &SDKBodyOperation{Operation: dyngo.NewOperation(parent)} - sharedsec.OnErrorData(op, func(e error) { + op := &types.SDKBodyOperation{Operation: dyngo.NewOperation(parent)} + dyngo.OnData(op, func(e error) { err = e }) dyngo.StartOperation(op, args) - dyngo.FinishOperation(op, SDKBodyOperationRes{}) + dyngo.FinishOperation(op, types.SDKBodyOperationRes{}) return err } @@ -128,10 +72,12 @@ func WrapHandler(handler http.Handler, span ddtrace.Span, pathParams map[string] var bypassHandler http.Handler var blocking bool args := MakeHandlerOperationArgs(r, clientIP, pathParams) - ctx, op := StartOperation(r.Context(), args, dyngo.NewDataListener(func(a *sharedsec.Action) { - bypassHandler = a.HTTP() - blocking = a.Blocking() - })) + ctx, op := StartOperation(r.Context(), args, func(op *types.Operation) { + dyngo.OnData(op, func(a *sharedsec.Action) { + bypassHandler = a.HTTP() + blocking = a.Blocking() + }) + }) r = r.WithContext(ctx) defer func() { @@ -170,11 +116,11 @@ func WrapHandler(handler http.Handler, span ddtrace.Span, pathParams map[string] } // MakeHandlerOperationArgs creates the HandlerOperationArgs value. -func MakeHandlerOperationArgs(r *http.Request, clientIP netip.Addr, pathParams map[string]string) HandlerOperationArgs { +func MakeHandlerOperationArgs(r *http.Request, clientIP netip.Addr, pathParams map[string]string) types.HandlerOperationArgs { cookies := makeCookies(r) // TODO(Julio-Guerra): avoid actively parsing the cookies thanks to dynamic instrumentation headers := headersRemoveCookies(r.Header) headers["host"] = []string{r.Host} - return HandlerOperationArgs{ + return types.HandlerOperationArgs{ Method: r.Method, RequestURI: r.RequestURI, Headers: headers, @@ -186,12 +132,12 @@ func MakeHandlerOperationArgs(r *http.Request, clientIP netip.Addr, pathParams m } // MakeHandlerOperationRes creates the HandlerOperationRes value. -func MakeHandlerOperationRes(w http.ResponseWriter) HandlerOperationRes { +func MakeHandlerOperationRes(w http.ResponseWriter) types.HandlerOperationRes { var status int if mw, ok := w.(interface{ Status() int }); ok { status = mw.Status() } - return HandlerOperationRes{Status: status, Headers: headersRemoveCookies(w.Header())} + return types.HandlerOperationRes{Status: status, Headers: headersRemoveCookies(w.Header())} } // Remove cookies from the request headers and return the map of headers @@ -222,119 +168,26 @@ func makeCookies(r *http.Request) map[string][]string { return cookies } -// TODO(Julio-Guerra): create a go-generate tool to generate the types, vars and methods below - -// Operation type representing an HTTP operation. It must be created with -// StartOperation() and finished with its Finish(). -type ( - Operation struct { - dyngo.Operation - trace.TagsHolder - trace.SecurityEventsHolder - mu sync.RWMutex - } - - // SDKBodyOperation type representing an SDK body - SDKBodyOperation struct { - dyngo.Operation - } -) - // StartOperation starts an HTTP handler operation, along with the given // context and arguments and emits a start event up in the operation stack. // The operation is linked to the global root operation since an HTTP operation // is always expected to be first in the operation stack. -func StartOperation(ctx context.Context, args HandlerOperationArgs, listeners ...dyngo.DataListener) (context.Context, *Operation) { - op := &Operation{ +func StartOperation(ctx context.Context, args types.HandlerOperationArgs, setup ...func(*types.Operation)) (context.Context, *types.Operation) { + op := &types.Operation{ Operation: dyngo.NewOperation(nil), TagsHolder: trace.NewTagsHolder(), } - for _, l := range listeners { - op.OnData(l) - } newCtx := context.WithValue(ctx, listener.ContextKey{}, op) + for _, cb := range setup { + cb(op) + } dyngo.StartOperation(op, args) return newCtx, op } // fromContext returns the Operation object stored in the context, if any -func fromContext(ctx context.Context) *Operation { +func fromContext(ctx context.Context) *types.Operation { // Avoid a runtime panic in case of type-assertion error by collecting the 2 return values - op, _ := ctx.Value(listener.ContextKey{}).(*Operation) + op, _ := ctx.Value(listener.ContextKey{}).(*types.Operation) return op } - -// Finish the HTTP handler operation, along with the given results and emits a -// finish event up in the operation stack. -func (op *Operation) Finish(res HandlerOperationRes) []any { - dyngo.FinishOperation(op, res) - return op.Events() -} - -// Finish finishes the SDKBody operation and emits a finish event -func (op *SDKBodyOperation) Finish() { - dyngo.FinishOperation(op, SDKBodyOperationRes{}) -} - -// HTTP handler operation's start and finish event callback function types. -type ( - // OnHandlerOperationStart function type, called when an HTTP handler - // operation starts. - OnHandlerOperationStart func(*Operation, HandlerOperationArgs) - // OnHandlerOperationFinish function type, called when an HTTP handler - // operation finishes. - OnHandlerOperationFinish func(*Operation, HandlerOperationRes) - // OnSDKBodyOperationStart function type, called when an SDK body - // operation starts. - OnSDKBodyOperationStart func(*SDKBodyOperation, SDKBodyOperationArgs) - // OnSDKBodyOperationFinish function type, called when an SDK body - // operation finishes. - OnSDKBodyOperationFinish func(*SDKBodyOperation, SDKBodyOperationRes) -) - -var ( - handlerOperationArgsType = reflect.TypeOf((*HandlerOperationArgs)(nil)).Elem() - handlerOperationResType = reflect.TypeOf((*HandlerOperationRes)(nil)).Elem() - sdkBodyOperationArgsType = reflect.TypeOf((*SDKBodyOperationArgs)(nil)).Elem() - sdkBodyOperationResType = reflect.TypeOf((*SDKBodyOperationRes)(nil)).Elem() -) - -// ListenedType returns the type a OnHandlerOperationStart event listener -// listens to, which is the HandlerOperationArgs type. -func (OnHandlerOperationStart) ListenedType() reflect.Type { return handlerOperationArgsType } - -// Call calls the underlying event listener function by performing the -// type-assertion on v whose type is the one returned by ListenedType(). -func (f OnHandlerOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op.(*Operation), v.(HandlerOperationArgs)) -} - -// ListenedType returns the type a OnHandlerOperationFinish event listener -// listens to, which is the HandlerOperationRes type. -func (OnHandlerOperationFinish) ListenedType() reflect.Type { return handlerOperationResType } - -// Call calls the underlying event listener function by performing the -// type-assertion on v whose type is the one returned by ListenedType(). -func (f OnHandlerOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op.(*Operation), v.(HandlerOperationRes)) -} - -// ListenedType returns the type a OnSDKBodyOperationStart event listener -// listens to, which is the SDKBodyOperationStartArgs type. -func (OnSDKBodyOperationStart) ListenedType() reflect.Type { return sdkBodyOperationArgsType } - -// Call calls the underlying event listener function by performing the -// type-assertion on v whose type is the one returned by ListenedType(). -func (f OnSDKBodyOperationStart) Call(op dyngo.Operation, v interface{}) { - f(op.(*SDKBodyOperation), v.(SDKBodyOperationArgs)) -} - -// ListenedType returns the type a OnSDKBodyOperationFinish event listener -// listens to, which is the SDKBodyOperationRes type. -func (OnSDKBodyOperationFinish) ListenedType() reflect.Type { return sdkBodyOperationResType } - -// Call calls the underlying event listener function by performing the -// type-assertion on v whose type is the one returned by ListenedType(). -func (f OnSDKBodyOperationFinish) Call(op dyngo.Operation, v interface{}) { - f(op.(*SDKBodyOperation), v.(SDKBodyOperationRes)) -} diff --git a/internal/appsec/emitter/httpsec/init.go b/internal/appsec/emitter/httpsec/init.go new file mode 100644 index 0000000000..9f4db28ff2 --- /dev/null +++ b/internal/appsec/emitter/httpsec/init.go @@ -0,0 +1,15 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package httpsec + +import ( + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/httpsec" +) + +func init() { + appsec.AddWAFEventListener(httpsec.Install) +} diff --git a/internal/appsec/emitter/httpsec/types/types.go b/internal/appsec/emitter/httpsec/types/types.go new file mode 100644 index 0000000000..04e481c124 --- /dev/null +++ b/internal/appsec/emitter/httpsec/types/types.go @@ -0,0 +1,103 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package types + +import ( + "net/netip" + "sync" + + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/trace" +) + +// Operation type representing an HTTP operation. It must be created with +// StartOperation() and finished with its Finish(). +type ( + Operation struct { + dyngo.Operation + trace.TagsHolder + trace.SecurityEventsHolder + mu sync.RWMutex + } + + // SDKBodyOperation type representing an SDK body + SDKBodyOperation struct { + dyngo.Operation + } +) + +// Finish the HTTP handler operation, along with the given results and emits a +// finish event up in the operation stack. +func (op *Operation) Finish(res HandlerOperationRes) []any { + dyngo.FinishOperation(op, res) + return op.Events() +} + +// Abstract HTTP handler operation definition. +type ( + // HandlerOperationArgs is the HTTP handler operation arguments. + HandlerOperationArgs struct { + // ClientIP corresponds to the address `http.client_ip` + ClientIP netip.Addr + // Headers corresponds to the address `server.request.headers.no_cookies` + Headers map[string][]string + // Cookies corresponds to the address `server.request.cookies` + Cookies map[string][]string + // Query corresponds to the address `server.request.query` + Query map[string][]string + // PathParams corresponds to the address `server.request.path_params` + PathParams map[string]string + // Method is the http method verb of the request, address is `server.request.method` + Method string + // RequestURI corresponds to the address `server.request.uri.raw` + RequestURI string + } + + // HandlerOperationRes is the HTTP handler operation results. + HandlerOperationRes struct { + Headers map[string][]string + // Status corresponds to the address `server.response.status`. + Status int + } + + // SDKBodyOperationArgs is the SDK body operation arguments. + SDKBodyOperationArgs struct { + // Body corresponds to the address `server.request.body`. + Body interface{} + } + + // SDKBodyOperationRes is the SDK body operation results. + SDKBodyOperationRes struct{} + + // MonitoringError is used to vehicle an HTTP error, usually resurfaced through Appsec SDKs. + MonitoringError struct { + msg string + } +) + +// Finish finishes the SDKBody operation and emits a finish event +func (op *SDKBodyOperation) Finish() { + dyngo.FinishOperation(op, SDKBodyOperationRes{}) +} + +// Error implements the Error interface +func (e *MonitoringError) Error() string { + return e.msg +} + +// NewMonitoringError creates and returns a new HTTP monitoring error, wrapped under +// sharedesec.MonitoringError +func NewMonitoringError(msg string) error { + return &MonitoringError{ + msg: msg, + } +} + +func (SDKBodyOperationArgs) IsArgOf(*SDKBodyOperation) {} +func (SDKBodyOperationRes) IsResultOf(*SDKBodyOperation) {} + +func (HandlerOperationArgs) IsArgOf(*Operation) {} +func (HandlerOperationRes) IsResultOf(*Operation) {} diff --git a/internal/appsec/emitter/sharedsec/shared.go b/internal/appsec/emitter/sharedsec/shared.go index b67ccf2ac5..715afc45cd 100644 --- a/internal/appsec/emitter/sharedsec/shared.go +++ b/internal/appsec/emitter/sharedsec/shared.go @@ -39,9 +39,7 @@ var userIDOperationArgsType = reflect.TypeOf((*UserIDOperationArgs)(nil)).Elem() func ExecuteUserIDOperation(parent dyngo.Operation, args UserIDOperationArgs) error { var err error op := &UserIDOperation{Operation: dyngo.NewOperation(parent)} - OnErrorData(op, func(e error) { - err = e - }) + dyngo.OnData(op, func(e error) { err = e }) dyngo.StartOperation(op, args) dyngo.FinishOperation(op, UserIDOperationRes{}) return err @@ -69,12 +67,5 @@ func MonitorUser(ctx context.Context, userID string) error { } -// OnData is a facilitator that wraps a dyngo.Operation.OnData() call -func OnData[T any](op dyngo.Operation, f func(T)) { - op.OnData(dyngo.NewDataListener(f)) -} - -// OnErrorData is a facilitator that wraps a dyngo.Operation.OnData() call with an error type constraint -func OnErrorData[T error](op dyngo.Operation, f func(T)) { - op.OnData(dyngo.NewDataListener(f)) -} +func (UserIDOperationArgs) IsArgOf(*UserIDOperation) {} +func (UserIDOperationRes) IsResultOf(*UserIDOperation) {} diff --git a/internal/appsec/listener/graphqlsec/graphql.go b/internal/appsec/listener/graphqlsec/graphql.go index d5d28c2700..19a7762543 100644 --- a/internal/appsec/listener/graphqlsec/graphql.go +++ b/internal/appsec/listener/graphqlsec/graphql.go @@ -7,16 +7,18 @@ package graphqlsec import ( "sync" - "time" "github.com/DataDog/appsec-internal-go/limiter" waf "github.com/DataDog/go-libddwaf/v2" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/config" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" - "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/graphqlsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/graphqlsec/types" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/sharedsec" - listener "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/sharedsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener" + shared "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/sharedsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/trace" + "gopkg.in/DataDog/dd-trace-go.v1/internal/log" "gopkg.in/DataDog/dd-trace-go.v1/internal/samplernames" ) @@ -26,71 +28,95 @@ const ( ) // List of GraphQL rule addresses currently supported by the WAF -var supportedAddresses = map[string]struct{}{ +var supportedAddresses = listener.AddressSet{ graphQLServerResolverAddr: {}, } -func SupportedAddressCount() int { - return len(supportedAddresses) +// Install registers the GraphQL WAF Event Listener on the given root operation. +func Install(wafHandle *waf.Handle, _ sharedsec.Actions, cfg *config.Config, lim limiter.Limiter, root dyngo.Operation) { + if listener := newWafEventListener(wafHandle, cfg, lim); listener != nil { + log.Debug("appsec: registering the GraphQL WAF Event Listener") + dyngo.On(root, listener.onEvent) + } } -func SupportsAddress(addr string) bool { - _, ok := supportedAddresses[addr] - return ok +type wafEventListener struct { + wafHandle *waf.Handle + config *config.Config + addresses map[string]struct{} + limiter limiter.Limiter + wafDiags waf.Diagnostics + once sync.Once +} + +func newWafEventListener(wafHandle *waf.Handle, cfg *config.Config, limiter limiter.Limiter) *wafEventListener { + if wafHandle == nil { + log.Debug("appsec: no WAF Handle available, the GraphQL WAF Event Listener will not be registered") + return nil + } + + addresses := listener.FilterAddressSet(supportedAddresses, wafHandle) + if len(addresses) == 0 { + log.Debug("appsec: no supported GraphQL address is used by currently loaded WAF rules, the GraphQL WAF Event Listener will not be registered") + return nil + } + + return &wafEventListener{ + wafHandle: wafHandle, + config: cfg, + addresses: addresses, + limiter: limiter, + wafDiags: wafHandle.Diagnostics(), + } } // NewWAFEventListener returns the WAF event listener to register in order // to enable it. -func NewWAFEventListener(handle *waf.Handle, _ sharedsec.Actions, addresses map[string]struct{}, timeout time.Duration, limiter limiter.Limiter) dyngo.EventListener { - var rulesMonitoringOnce sync.Once - wafDiags := handle.Diagnostics() - - return graphqlsec.OnRequestOperationStart(func(request *graphqlsec.RequestOperation, args graphqlsec.RequestOperationArgs) { - wafCtx := waf.NewContext(handle) - if wafCtx == nil { - return - } - - // Add span tags notifying this trace is AppSec-enabled - trace.SetAppSecEnabledTags(request) - rulesMonitoringOnce.Do(func() { - listener.AddRulesMonitoringTags(request, &wafDiags) - request.SetTag(ext.ManualKeep, samplernames.AppSec) - }) +func (l *wafEventListener) onEvent(request *types.RequestOperation, _ types.RequestOperationArgs) { + wafCtx := waf.NewContext(l.wafHandle) + if wafCtx == nil { + return + } + + // Add span tags notifying this trace is AppSec-enabled + trace.SetAppSecEnabledTags(request) + l.once.Do(func() { + shared.AddRulesMonitoringTags(request, &l.wafDiags) + request.SetTag(ext.ManualKeep, samplernames.AppSec) + }) - request.On(graphqlsec.OnExecutionOperationStart(func(query *graphqlsec.ExecutionOperation, args graphqlsec.ExecutionOperationArgs) { - query.On(graphqlsec.OnResolveOperationStart(func(field *graphqlsec.ResolveOperation, args graphqlsec.ResolveOperationArgs) { - if _, found := addresses[graphQLServerResolverAddr]; found { - wafResult := listener.RunWAF( - wafCtx, - waf.RunAddressData{ - Ephemeral: map[string]any{ - graphQLServerResolverAddr: map[string]any{args.FieldName: args.Arguments}, - }, + dyngo.On(request, func(query *types.ExecutionOperation, args types.ExecutionOperationArgs) { + dyngo.On(query, func(field *types.ResolveOperation, args types.ResolveOperationArgs) { + if _, found := l.addresses[graphQLServerResolverAddr]; found { + wafResult := shared.RunWAF( + wafCtx, + waf.RunAddressData{ + Ephemeral: map[string]any{ + graphQLServerResolverAddr: map[string]any{args.FieldName: args.Arguments}, }, - timeout, - ) - listener.AddSecurityEvents(field, limiter, wafResult.Events) - } - - field.On(graphqlsec.OnResolveOperationFinish(func(field *graphqlsec.ResolveOperation, res graphqlsec.ResolveOperationRes) { - trace.SetEventSpanTags(field, field.Events()) - })) - })) - - query.On(graphqlsec.OnExecutionOperationFinish(func(query *graphqlsec.ExecutionOperation, res graphqlsec.ExecutionOperationRes) { - trace.SetEventSpanTags(query, query.Events()) - })) - })) - - request.On(graphqlsec.OnRequestOperationFinish(func(request *graphqlsec.RequestOperation, res graphqlsec.RequestOperationRes) { - defer wafCtx.Close() - - overall, internal := wafCtx.TotalRuntime() - nbTimeouts := wafCtx.TotalTimeouts() - listener.AddWAFMonitoringTags(request, wafDiags.Version, overall, internal, nbTimeouts) - - trace.SetEventSpanTags(request, request.Events()) - })) + }, + l.config.WAFTimeout, + ) + shared.AddSecurityEvents(field, l.limiter, wafResult.Events) + } + + dyngo.OnFinish(field, func(field *types.ResolveOperation, res types.ResolveOperationRes) { + trace.SetEventSpanTags(field, field.Events()) + }) + }) + + dyngo.OnFinish(query, func(query *types.ExecutionOperation, res types.ExecutionOperationRes) { + trace.SetEventSpanTags(query, query.Events()) + }) + }) + + dyngo.OnFinish(request, func(request *types.RequestOperation, res types.RequestOperationRes) { + defer wafCtx.Close() + + overall, internal := wafCtx.TotalRuntime() + nbTimeouts := wafCtx.TotalTimeouts() + shared.AddWAFMonitoringTags(request, l.wafDiags.Version, overall, internal, nbTimeouts) + + trace.SetEventSpanTags(request, request.Events()) }) } diff --git a/internal/appsec/listener/grpcsec/grpc.go b/internal/appsec/listener/grpcsec/grpc.go index a292f035c2..fb401fbf49 100644 --- a/internal/appsec/listener/grpcsec/grpc.go +++ b/internal/appsec/listener/grpcsec/grpc.go @@ -7,18 +7,19 @@ package grpcsec import ( "sync" - "time" "go.uber.org/atomic" "github.com/DataDog/appsec-internal-go/limiter" waf "github.com/DataDog/go-libddwaf/v2" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/config" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" - "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/grpcsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/grpcsec/types" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/sharedsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/httpsec" - listener "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/sharedsec" + shared "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/sharedsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" "gopkg.in/DataDog/dd-trace-go.v1/internal/samplernames" ) @@ -32,131 +33,157 @@ const ( ) // List of gRPC rule addresses currently supported by the WAF -var supportedAddresses = map[string]struct{}{ +var supportedAddresses = listener.AddressSet{ GRPCServerRequestMessage: {}, GRPCServerRequestMetadata: {}, HTTPClientIPAddr: {}, UserIDAddr: {}, } -func SupportedAddressCount() int { - return len(supportedAddresses) +// Install registers the gRPC WAF Event Listener on the given root operation. +func Install(wafHandle *waf.Handle, actions sharedsec.Actions, cfg *config.Config, lim limiter.Limiter, root dyngo.Operation) { + if listener := newWafEventListener(wafHandle, actions, cfg, lim); listener != nil { + log.Debug("appsec: registering the gRPC WAF Event Listener") + dyngo.On(root, listener.onEvent) + } } -func SupportsAddress(addr string) bool { - _, ok := supportedAddresses[addr] - return ok +type wafEventListener struct { + wafHandle *waf.Handle + config *config.Config + actions sharedsec.Actions + addresses map[string]struct{} + limiter limiter.Limiter + wafDiags waf.Diagnostics + once sync.Once +} + +func newWafEventListener(wafHandle *waf.Handle, actions sharedsec.Actions, cfg *config.Config, limiter limiter.Limiter) *wafEventListener { + if wafHandle == nil { + log.Debug("appsec: no WAF Handle available, the gRPC WAF Event Listener will not be registered") + return nil + } + + addresses := listener.FilterAddressSet(supportedAddresses, wafHandle) + if len(addresses) == 0 { + log.Debug("appsec: no supported gRPC address is used by currently loaded WAF rules, the gRPC WAF Event Listener will not be registered") + return nil + } + + return &wafEventListener{ + wafHandle: wafHandle, + config: cfg, + actions: actions, + addresses: addresses, + limiter: limiter, + wafDiags: wafHandle.Diagnostics(), + } } // NewWAFEventListener returns the WAF event listener to register in order // to enable it. -func NewWAFEventListener(handle *waf.Handle, actions sharedsec.Actions, addresses map[string]struct{}, timeout time.Duration, limiter limiter.Limiter) dyngo.EventListener { - var monitorRulesOnce sync.Once // per instantiation - wafDiags := handle.Diagnostics() - - return grpcsec.OnHandlerOperationStart(func(op *grpcsec.HandlerOperation, handlerArgs grpcsec.HandlerOperationArgs) { - // Limit the maximum number of security events, as a streaming RPC could - // receive unlimited number of messages where we could find security events - const maxWAFEventsPerRequest = 10 - var ( - nbEvents atomic.Uint32 - logOnce sync.Once // per request - - events []any - mu sync.Mutex // events mutex - ) - - wafCtx := waf.NewContext(handle) - if wafCtx == nil { - // The WAF event listener got concurrently released - return - } - - // OnUserIDOperationStart happens when appsec.SetUser() is called. We run the WAF and apply actions to - // see if the associated user should be blocked. Since we don't control the execution flow in this case - // (SetUser is SDK), we delegate the responsibility of interrupting the handler to the user. - op.On(sharedsec.OnUserIDOperationStart(func(userIDOp *sharedsec.UserIDOperation, args sharedsec.UserIDOperationArgs) { - values := map[string]any{} - for addr := range addresses { - if addr == UserIDAddr { - values[UserIDAddr] = args.UserID - } +func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types.HandlerOperationArgs) { + // Limit the maximum number of security events, as a streaming RPC could + // receive unlimited number of messages where we could find security events + const maxWAFEventsPerRequest = 10 + var ( + nbEvents atomic.Uint32 + logOnce sync.Once // per request + + events []any + mu sync.Mutex // events mutex + ) + + wafCtx := waf.NewContext(l.wafHandle) + if wafCtx == nil { + // The WAF event listener got concurrently released + return + } + + // OnUserIDOperationStart happens when appsec.SetUser() is called. We run the WAF and apply actions to + // see if the associated user should be blocked. Since we don't control the execution flow in this case + // (SetUser is SDK), we delegate the responsibility of interrupting the handler to the user. + dyngo.On(op, func(userIDOp *sharedsec.UserIDOperation, args sharedsec.UserIDOperationArgs) { + values := make(map[string]any, 1) + for addr := range l.addresses { + if addr == UserIDAddr { + values[UserIDAddr] = args.UserID } - wafResult := listener.RunWAF(wafCtx, waf.RunAddressData{Persistent: values}, timeout) - if wafResult.HasActions() || wafResult.HasEvents() { - for _, id := range wafResult.Actions { - if a, ok := actions[id]; ok && a.Blocking() { - code, err := a.GRPC()(map[string][]string{}) - userIDOp.EmitData(grpcsec.NewMonitoringError(err.Error(), code)) - } + } + wafResult := shared.RunWAF(wafCtx, waf.RunAddressData{Persistent: values}, l.config.WAFTimeout) + if wafResult.HasActions() || wafResult.HasEvents() { + for _, id := range wafResult.Actions { + if a, ok := l.actions[id]; ok && a.Blocking() { + code, err := a.GRPC()(map[string][]string{}) + dyngo.EmitData(userIDOp, types.NewMonitoringError(err.Error(), code)) } - listener.AddSecurityEvents(op, limiter, wafResult.Events) - log.Debug("appsec: WAF detected an authenticated user attack: %s", args.UserID) - } - })) - - // The same address is used for gRPC and http when it comes to client ip - values := map[string]any{} - for addr := range addresses { - if addr == HTTPClientIPAddr && handlerArgs.ClientIP.IsValid() { - values[HTTPClientIPAddr] = handlerArgs.ClientIP.String() } + shared.AddSecurityEvents(op, l.limiter, wafResult.Events) + log.Debug("appsec: WAF detected an authenticated user attack: %s", args.UserID) } + }) - wafResult := listener.RunWAF(wafCtx, waf.RunAddressData{Persistent: values}, timeout) - if wafResult.HasActions() || wafResult.HasEvents() { - interrupt := listener.ProcessActions(op, actions, wafResult.Actions) - listener.AddSecurityEvents(op, limiter, wafResult.Events) - log.Debug("appsec: WAF detected an attack before executing the request") - if interrupt { - wafCtx.Close() - return - } + // The same address is used for gRPC and http when it comes to client ip + values := make(map[string]any, 1) + for addr := range l.addresses { + if addr == HTTPClientIPAddr && handlerArgs.ClientIP.IsValid() { + values[HTTPClientIPAddr] = handlerArgs.ClientIP.String() } + } + + wafResult := shared.RunWAF(wafCtx, waf.RunAddressData{Persistent: values}, l.config.WAFTimeout) + if wafResult.HasActions() || wafResult.HasEvents() { + interrupt := shared.ProcessActions(op, l.actions, wafResult.Actions) + shared.AddSecurityEvents(op, l.limiter, wafResult.Events) + log.Debug("appsec: WAF detected an attack before executing the request") + if interrupt { + wafCtx.Close() + return + } + } - op.On(grpcsec.OnReceiveOperationFinish(func(_ grpcsec.ReceiveOperation, res grpcsec.ReceiveOperationRes) { - if nbEvents.Load() == maxWAFEventsPerRequest { - logOnce.Do(func() { - log.Debug("appsec: ignoring the rpc message due to the maximum number of security events per grpc call reached") - }) - return - } + dyngo.OnFinish(op, func(_ types.ReceiveOperation, res types.ReceiveOperationRes) { + if nbEvents.Load() == maxWAFEventsPerRequest { + logOnce.Do(func() { + log.Debug("appsec: ignoring the rpc message due to the maximum number of security events per grpc call reached") + }) + return + } - // Run the WAF on the rule addresses available in the args - // Note that we don't check if the address is present in the rules - // as we only support one at the moment, so this callback cannot be - // set when the address is not present. - values := waf.RunAddressData{ - Ephemeral: map[string]any{GRPCServerRequestMessage: res.Message}, - } - if md := handlerArgs.Metadata; len(md) > 0 { - values.Persistent = map[string]any{GRPCServerRequestMetadata: md} - } - // Run the WAF, ignoring the returned actions - if any - since blocking after the request handler's - // response is not supported at the moment. - wafResult := listener.RunWAF(wafCtx, values, timeout) - - if wafResult.HasEvents() { - log.Debug("appsec: attack detected by the grpc waf") - nbEvents.Inc() - mu.Lock() - defer mu.Unlock() - events = append(events, wafResult.Events...) - } - })) + // Run the WAF on the rule addresses available in the args + // Note that we don't check if the address is present in the rules + // as we only support one at the moment, so this callback cannot be + // set when the address is not present. + values := waf.RunAddressData{ + Ephemeral: map[string]any{GRPCServerRequestMessage: res.Message}, + } + if md := handlerArgs.Metadata; len(md) > 0 { + values.Persistent = map[string]any{GRPCServerRequestMetadata: md} + } + // Run the WAF, ignoring the returned actions - if any - since blocking after the request handler's + // response is not supported at the moment. + wafResult := shared.RunWAF(wafCtx, values, l.config.WAFTimeout) + + if wafResult.HasEvents() { + log.Debug("appsec: attack detected by the grpc waf") + nbEvents.Inc() + mu.Lock() + defer mu.Unlock() + events = append(events, wafResult.Events...) + } + }) - op.On(grpcsec.OnHandlerOperationFinish(func(op *grpcsec.HandlerOperation, _ grpcsec.HandlerOperationRes) { - defer wafCtx.Close() - overallRuntimeNs, internalRuntimeNs := wafCtx.TotalRuntime() - listener.AddWAFMonitoringTags(op, wafDiags.Version, overallRuntimeNs, internalRuntimeNs, wafCtx.TotalTimeouts()) + dyngo.OnFinish(op, func(op *types.HandlerOperation, _ types.HandlerOperationRes) { + defer wafCtx.Close() + overallRuntimeNs, internalRuntimeNs := wafCtx.TotalRuntime() + shared.AddWAFMonitoringTags(op, l.wafDiags.Version, overallRuntimeNs, internalRuntimeNs, wafCtx.TotalTimeouts()) - // Log the following metrics once per instantiation of a WAF handle - monitorRulesOnce.Do(func() { - listener.AddRulesMonitoringTags(op, &wafDiags) - op.SetTag(ext.ManualKeep, samplernames.AppSec) - }) + // Log the following metrics once per instantiation of a WAF handle + l.once.Do(func() { + shared.AddRulesMonitoringTags(op, &l.wafDiags) + op.SetTag(ext.ManualKeep, samplernames.AppSec) + }) - listener.AddSecurityEvents(op, limiter, events) - })) + shared.AddSecurityEvents(op, l.limiter, events) }) } diff --git a/internal/appsec/listener/httpsec/http.go b/internal/appsec/listener/httpsec/http.go index 7e1d6b1cb2..eca709185e 100644 --- a/internal/appsec/listener/httpsec/http.go +++ b/internal/appsec/listener/httpsec/http.go @@ -9,16 +9,16 @@ import ( "fmt" "math/rand" "sync" - "time" - internal "github.com/DataDog/appsec-internal-go/appsec" "github.com/DataDog/appsec-internal-go/limiter" waf "github.com/DataDog/go-libddwaf/v2" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/config" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" - "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/httpsec" - emitter "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/sharedsec" - listener "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/sharedsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/httpsec/types" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/sharedsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener" + shared "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/sharedsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" "gopkg.in/DataDog/dd-trace-go.v1/internal/samplernames" ) @@ -39,7 +39,7 @@ const ( ) // List of HTTP rule addresses currently supported by the WAF -var supportedAddresses = map[string]struct{}{ +var supportedAddresses = listener.AddressSet{ ServerRequestMethodAddr: {}, ServerRequestRawURIAddr: {}, ServerRequestHeadersNoCookiesAddr: {}, @@ -53,147 +53,188 @@ var supportedAddresses = map[string]struct{}{ UserIDAddr: {}, } -func SupportedAddressCount() int { - return len(supportedAddresses) +// Install registers the HTTP WAF Event Listener on the given root operation. +func Install(wafHandle *waf.Handle, actions sharedsec.Actions, cfg *config.Config, lim limiter.Limiter, root dyngo.Operation) { + if listener := newWafEventListener(wafHandle, actions, cfg, lim); listener != nil { + log.Debug("appsec: registering the HTTP WAF Event Listener") + dyngo.On(root, listener.onEvent) + } } -func SupportsAddress(addr string) bool { - _, ok := supportedAddresses[addr] - return ok +type wafEventListener struct { + wafHandle *waf.Handle + config *config.Config + actions sharedsec.Actions + addresses map[string]struct{} + limiter limiter.Limiter + wafDiags waf.Diagnostics + once sync.Once } -// NewWAFEventListener returns the WAF event listener to register in order to enable it. -func NewWAFEventListener(handle *waf.Handle, actions emitter.Actions, addresses map[string]struct{}, timeout time.Duration, apiSecCfg *internal.APISecConfig, limiter limiter.Limiter) dyngo.EventListener { - var monitorRulesOnce sync.Once // per instantiation - // TODO: port wafDiags to telemetry metrics and logs instead of span tags (ultimately removing them from here hopefully) - wafDiags := handle.Diagnostics() - - return httpsec.OnHandlerOperationStart(func(op *httpsec.Operation, args httpsec.HandlerOperationArgs) { - wafCtx := waf.NewContext(handle) - - if wafCtx == nil { - // The WAF event listener got concurrently released - return - } - - if _, ok := addresses[UserIDAddr]; ok { - // OnUserIDOperationStart happens when appsec.SetUser() is called. We run the WAF and apply actions to - // see if the associated user should be blocked. Since we don't control the execution flow in this case - // (SetUser is SDK), we delegate the responsibility of interrupting the handler to the user. - op.On(emitter.OnUserIDOperationStart(func(operation *emitter.UserIDOperation, args emitter.UserIDOperationArgs) { - wafResult := listener.RunWAF(wafCtx, waf.RunAddressData{Persistent: map[string]any{UserIDAddr: args.UserID}}, timeout) - if wafResult.HasActions() || wafResult.HasEvents() { - listener.ProcessHTTPSDKAction(operation, actions, wafResult.Actions) - listener.AddSecurityEvents(op, limiter, wafResult.Events) - log.Debug("appsec: WAF detected a suspicious user: %s", args.UserID) - } - })) - } +func newWafEventListener(wafHandle *waf.Handle, actions sharedsec.Actions, cfg *config.Config, limiter limiter.Limiter) *wafEventListener { + if wafHandle == nil { + log.Debug("appsec: no WAF Handle available, the HTTP WAF Event Listener will not be registered") + return nil + } + + addresses := listener.FilterAddressSet(supportedAddresses, wafHandle) + if len(addresses) == 0 { + log.Debug("appsec: no supported HTTP address is used by currently loaded WAF rules, the HTTP WAF Event Listener will not be registered") + return nil + } + + return &wafEventListener{ + wafHandle: wafHandle, + config: cfg, + actions: actions, + addresses: addresses, + limiter: limiter, + wafDiags: wafHandle.Diagnostics(), + } +} - values := make(map[string]any, 8) - for addr := range addresses { - switch addr { - case HTTPClientIPAddr: - if args.ClientIP.IsValid() { - values[HTTPClientIPAddr] = args.ClientIP.String() - } - case ServerRequestMethodAddr: - values[ServerRequestMethodAddr] = args.Method - case ServerRequestRawURIAddr: - values[ServerRequestRawURIAddr] = args.RequestURI - case ServerRequestHeadersNoCookiesAddr: - if headers := args.Headers; headers != nil { - values[ServerRequestHeadersNoCookiesAddr] = headers - } - case ServerRequestCookiesAddr: - if cookies := args.Cookies; cookies != nil { - values[ServerRequestCookiesAddr] = cookies - } - case ServerRequestQueryAddr: - if query := args.Query; query != nil { - values[ServerRequestQueryAddr] = query - } - case ServerRequestPathParamsAddr: - if pathParams := args.PathParams; pathParams != nil { - values[ServerRequestPathParamsAddr] = pathParams - } +// NewWAFEventListener returns the WAF event listener to register in order to enable it. +func (l *wafEventListener) onEvent(op *types.Operation, args types.HandlerOperationArgs) { + wafCtx := waf.NewContext(l.wafHandle) + if wafCtx == nil { + // The WAF event listener got concurrently released + return + } + + if _, ok := l.addresses[UserIDAddr]; ok { + // OnUserIDOperationStart happens when appsec.SetUser() is called. We run the WAF and apply actions to + // see if the associated user should be blocked. Since we don't control the execution flow in this case + // (SetUser is SDK), we delegate the responsibility of interrupting the handler to the user. + dyngo.On(op, func(operation *sharedsec.UserIDOperation, args sharedsec.UserIDOperationArgs) { + wafResult := shared.RunWAF(wafCtx, waf.RunAddressData{Persistent: map[string]any{UserIDAddr: args.UserID}}, l.config.WAFTimeout) + if wafResult.HasActions() || wafResult.HasEvents() { + processHTTPSDKAction(operation, l.actions, wafResult.Actions) + shared.AddSecurityEvents(op, l.limiter, wafResult.Events) + log.Debug("appsec: WAF detected a suspicious user: %s", args.UserID) + } + }) + } + + values := make(map[string]any, 8) + for addr := range l.addresses { + switch addr { + case HTTPClientIPAddr: + if args.ClientIP.IsValid() { + values[HTTPClientIPAddr] = args.ClientIP.String() + } + case ServerRequestMethodAddr: + values[ServerRequestMethodAddr] = args.Method + case ServerRequestRawURIAddr: + values[ServerRequestRawURIAddr] = args.RequestURI + case ServerRequestHeadersNoCookiesAddr: + if headers := args.Headers; headers != nil { + values[ServerRequestHeadersNoCookiesAddr] = headers + } + case ServerRequestCookiesAddr: + if cookies := args.Cookies; cookies != nil { + values[ServerRequestCookiesAddr] = cookies + } + case ServerRequestQueryAddr: + if query := args.Query; query != nil { + values[ServerRequestQueryAddr] = query + } + case ServerRequestPathParamsAddr: + if pathParams := args.PathParams; pathParams != nil { + values[ServerRequestPathParamsAddr] = pathParams } } - if canExtractSchemas(apiSecCfg) { - // This address will be passed as persistent. The WAF will keep it in store and trigger schema extraction - // for each run. - values["waf.context.processor"] = map[string]any{"extract-schema": true} + } + if l.canExtractSchemas() { + // This address will be passed as persistent. The WAF will keep it in store and trigger schema extraction + // for each run. + values["waf.context.processor"] = map[string]any{"extract-schema": true} + } + + wafResult := shared.RunWAF(wafCtx, waf.RunAddressData{Persistent: values}, l.config.WAFTimeout) + for tag, value := range wafResult.Derivatives { + op.AddSerializableTag(tag, value) + } + if wafResult.HasActions() || wafResult.HasEvents() { + interrupt := shared.ProcessActions(op, l.actions, wafResult.Actions) + shared.AddSecurityEvents(op, l.limiter, wafResult.Events) + log.Debug("appsec: WAF detected an attack before executing the request") + if interrupt { + wafCtx.Close() + return } + } - wafResult := listener.RunWAF(wafCtx, waf.RunAddressData{Persistent: values}, timeout) - for tag, value := range wafResult.Derivatives { - op.AddSerializableTag(tag, value) - } - if wafResult.HasActions() || wafResult.HasEvents() { - interrupt := listener.ProcessActions(op, actions, wafResult.Actions) - listener.AddSecurityEvents(op, limiter, wafResult.Events) - log.Debug("appsec: WAF detected an attack before executing the request") - if interrupt { - wafCtx.Close() - return + if _, ok := l.addresses[ServerRequestBodyAddr]; ok { + dyngo.On(op, func(sdkBodyOp *types.SDKBodyOperation, args types.SDKBodyOperationArgs) { + wafResult := shared.RunWAF(wafCtx, waf.RunAddressData{Persistent: map[string]any{ServerRequestBodyAddr: args.Body}}, l.config.WAFTimeout) + for tag, value := range wafResult.Derivatives { + op.AddSerializableTag(tag, value) } - } - - if _, ok := addresses[ServerRequestBodyAddr]; ok { - op.On(httpsec.OnSDKBodyOperationStart(func(sdkBodyOp *httpsec.SDKBodyOperation, args httpsec.SDKBodyOperationArgs) { - wafResult := listener.RunWAF(wafCtx, waf.RunAddressData{Persistent: map[string]any{ServerRequestBodyAddr: args.Body}}, timeout) - for tag, value := range wafResult.Derivatives { - op.AddSerializableTag(tag, value) - } - if wafResult.HasActions() || wafResult.HasEvents() { - listener.ProcessHTTPSDKAction(sdkBodyOp, actions, wafResult.Actions) - listener.AddSecurityEvents(op, limiter, wafResult.Events) - log.Debug("appsec: WAF detected a suspicious request body") - } - })) - } + if wafResult.HasActions() || wafResult.HasEvents() { + processHTTPSDKAction(sdkBodyOp, l.actions, wafResult.Actions) + shared.AddSecurityEvents(op, l.limiter, wafResult.Events) + log.Debug("appsec: WAF detected a suspicious request body") + } + }) + } - op.On(httpsec.OnHandlerOperationFinish(func(op *httpsec.Operation, res httpsec.HandlerOperationRes) { - defer wafCtx.Close() + dyngo.OnFinish(op, func(op *types.Operation, res types.HandlerOperationRes) { + defer wafCtx.Close() - values = make(map[string]any, 2) - if _, ok := addresses[ServerResponseStatusAddr]; ok { - // serverResponseStatusAddr is a string address, so we must format the status code... - values[ServerResponseStatusAddr] = fmt.Sprintf("%d", res.Status) - } + values = make(map[string]any, 2) + if _, ok := l.addresses[ServerResponseStatusAddr]; ok { + // serverResponseStatusAddr is a string address, so we must format the status code... + values[ServerResponseStatusAddr] = fmt.Sprintf("%d", res.Status) + } - if _, ok := addresses[ServerResponseHeadersNoCookiesAddr]; ok && res.Headers != nil { - values[ServerResponseHeadersNoCookiesAddr] = res.Headers - } + if _, ok := l.addresses[ServerResponseHeadersNoCookiesAddr]; ok && res.Headers != nil { + values[ServerResponseHeadersNoCookiesAddr] = res.Headers + } - // Run the WAF, ignoring the returned actions - if any - since blocking after the request handler's - // response is not supported at the moment. - wafResult := listener.RunWAF(wafCtx, waf.RunAddressData{Persistent: values}, timeout) + // Run the WAF, ignoring the returned actions - if any - since blocking after the request handler's + // response is not supported at the moment. + wafResult := shared.RunWAF(wafCtx, waf.RunAddressData{Persistent: values}, l.config.WAFTimeout) - // Add WAF metrics. - overallRuntimeNs, internalRuntimeNs := wafCtx.TotalRuntime() - listener.AddWAFMonitoringTags(op, wafDiags.Version, overallRuntimeNs, internalRuntimeNs, wafCtx.TotalTimeouts()) + // Add WAF metrics. + overallRuntimeNs, internalRuntimeNs := wafCtx.TotalRuntime() + shared.AddWAFMonitoringTags(op, l.wafDiags.Version, overallRuntimeNs, internalRuntimeNs, wafCtx.TotalTimeouts()) - // Add the following metrics once per instantiation of a WAF handle - monitorRulesOnce.Do(func() { - listener.AddRulesMonitoringTags(op, &wafDiags) - op.SetTag(ext.ManualKeep, samplernames.AppSec) - }) + // Add the following metrics once per instantiation of a WAF handle + l.once.Do(func() { + shared.AddRulesMonitoringTags(op, &l.wafDiags) + op.SetTag(ext.ManualKeep, samplernames.AppSec) + }) - // Log the attacks if any - if wafResult.HasEvents() { - log.Debug("appsec: attack detected by the waf") - listener.AddSecurityEvents(op, limiter, wafResult.Events) - } - for tag, value := range wafResult.Derivatives { - op.AddSerializableTag(tag, value) - } - })) + // Log the attacks if any + if wafResult.HasEvents() { + log.Debug("appsec: attack detected by the waf") + shared.AddSecurityEvents(op, l.limiter, wafResult.Events) + } + for tag, value := range wafResult.Derivatives { + op.AddSerializableTag(tag, value) + } }) } // canExtractSchemas checks that API Security is enabled and that sampling rate // allows extracting schemas -func canExtractSchemas(cfg *internal.APISecConfig) bool { - return cfg != nil && cfg.Enabled && cfg.SampleRate >= rand.Float64() +func (l *wafEventListener) canExtractSchemas() bool { + return l.config.APISec.Enabled && l.config.APISec.SampleRate >= rand.Float64() +} + +// processHTTPSDKAction does two things: +// - send actions to the parent operation's data listener, for their handlers to be executed after the user handler +// - send an error to the current operation's data listener (created by an SDK call), to signal users to interrupt +// their handler. +func processHTTPSDKAction(op dyngo.Operation, actions sharedsec.Actions, actionIds []string) { + for _, id := range actionIds { + if action, ok := actions[id]; ok { + if op.Parent() != nil { + dyngo.EmitData(op, action) // Send the action so that the handler gets executed + } + if action.Blocking() { // Send the error to be returned by the SDK + dyngo.EmitData(op, types.NewMonitoringError("Request blocked")) // Send error + } + } + } } diff --git a/internal/appsec/listener/listener.go b/internal/appsec/listener/listener.go index 00176e7ceb..f1668152ed 100644 --- a/internal/appsec/listener/listener.go +++ b/internal/appsec/listener/listener.go @@ -8,5 +8,24 @@ // types found in gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter. package listener +import waf "github.com/DataDog/go-libddwaf/v2" + // ContextKey is used as a key to store operations in the request's context (gRPC/HTTP) type ContextKey struct{} + +// AddressSet is a set of WAF addresses. +type AddressSet map[string]struct{} + +// FilterAddressSet filters the supplied `supported` address set to only include +// entries referenced by the supplied waf.Handle. +func FilterAddressSet(supported AddressSet, handle *waf.Handle) AddressSet { + result := make(AddressSet, len(supported)) + + for _, addr := range handle.Addresses() { + if _, found := supported[addr]; found { + result[addr] = struct{}{} + } + } + + return result +} diff --git a/internal/appsec/listener/sharedsec/shared.go b/internal/appsec/listener/sharedsec/shared.go index 8915fb236c..126a95065e 100644 --- a/internal/appsec/listener/sharedsec/shared.go +++ b/internal/appsec/listener/sharedsec/shared.go @@ -12,7 +12,6 @@ import ( "github.com/DataDog/appsec-internal-go/limiter" waf "github.com/DataDog/go-libddwaf/v2" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" - "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/httpsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/sharedsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/trace" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" @@ -84,26 +83,9 @@ func AddWAFMonitoringTags(th trace.TagSetter, rulesVersion string, overallRuntim func ProcessActions(op dyngo.Operation, actions sharedsec.Actions, actionIds []string) (interrupt bool) { for _, id := range actionIds { if a, ok := actions[id]; ok { - op.EmitData(actions[id]) + dyngo.EmitData(op, actions[id]) interrupt = interrupt || a.Blocking() } } return interrupt } - -// ProcessHTTPSDKAction does two things: -// - send actions to the parent operation's data listener, for their handlers to be executed after the user handler -// - send an error to the current operation's data listener (created by an SDK call), to signal users to interrupt -// their handler. -func ProcessHTTPSDKAction(op dyngo.Operation, actions sharedsec.Actions, actionIds []string) { - for _, id := range actionIds { - if action, ok := actions[id]; ok { - if op.Parent() != nil { - op.Parent().EmitData(action) // Send the action so that the handler gets executed - } - if action.Blocking() { // Send the error to be returned by the SDK - op.EmitData(httpsec.NewMonitoringError("Request blocked")) // Send error - } - } - } -} diff --git a/internal/appsec/waf.go b/internal/appsec/waf.go index 284ceed7a5..8e74bfca89 100644 --- a/internal/appsec/waf.go +++ b/internal/appsec/waf.go @@ -6,30 +6,14 @@ package appsec import ( - "errors" - "github.com/DataDog/appsec-internal-go/limiter" waf "github.com/DataDog/go-libddwaf/v2" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/config" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/emitter/sharedsec" - "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/graphqlsec" - "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/grpcsec" - "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/listener/httpsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" ) -const ( - eventRulesVersionTag = "_dd.appsec.event_rules.version" - eventRulesErrorsTag = "_dd.appsec.event_rules.errors" - eventRulesLoadedTag = "_dd.appsec.event_rules.loaded" - eventRulesFailedTag = "_dd.appsec.event_rules.error_count" - wafDurationTag = "_dd.appsec.waf.duration" - wafDurationExtTag = "_dd.appsec.waf.duration_ext" - wafTimeoutTag = "_dd.appsec.waf.timeouts" - wafVersionTag = "_dd.appsec.waf.version" -) - type wafHandle struct { *waf.Handle // actions are tightly link to a ruleset, which is linked to a waf handle @@ -50,15 +34,9 @@ func (a *appsec) swapWAF(rules config.RulesFragment) (err error) { } }() - listeners, err := newWAFEventListeners(newHandle, a.cfg, a.limiter) - if err != nil { - return err - } - - // Register the event listeners now that we know that the new handle is valid newRoot := dyngo.NewRootOperation() - for _, l := range listeners { - newRoot.On(l) + for _, fn := range wafEventListeners { + fn(newHandle.Handle, newHandle.actions, a.cfg, a.limiter, newRoot) } // Hot-swap dyngo's root operation @@ -115,56 +93,17 @@ func newWAFHandle(rules config.RulesFragment, cfg *config.Config) (*wafHandle, e }, err } -func newWAFEventListeners(waf *wafHandle, cfg *config.Config, l limiter.Limiter) (listeners []dyngo.EventListener, err error) { - // Check if there are addresses in the rule - ruleAddresses := waf.Addresses() - if len(ruleAddresses) == 0 { - return nil, errors.New("no addresses found in the rule") - } +type wafEventListener func(*waf.Handle, sharedsec.Actions, *config.Config, limiter.Limiter, dyngo.Operation) - // Check which addresses are supported by what listener - graphQLAddresses := make(map[string]struct{}, graphqlsec.SupportedAddressCount()) - grpcAddresses := make(map[string]struct{}, grpcsec.SupportedAddressCount()) - httpAddresses := make(map[string]struct{}, httpsec.SupportedAddressCount()) - notSupported := make([]string, 0, len(ruleAddresses)) - for _, address := range ruleAddresses { - supported := false - if graphqlsec.SupportsAddress(address) { - graphQLAddresses[address] = struct{}{} - supported = true - } - if grpcsec.SupportsAddress(address) { - grpcAddresses[address] = struct{}{} - supported = true - } - if httpsec.SupportsAddress(address) { - httpAddresses[address] = struct{}{} - supported = true - } - if !supported { - notSupported = append(notSupported, address) - } - } - - if len(notSupported) > 0 { - log.Debug("appsec: the addresses present in the rules are partially supported: not supported=%v", notSupported) - } - - // Register the WAF event listeners - if len(graphQLAddresses) > 0 { - log.Debug("appsec: creating the GraphQL waf event listener of the rules addresses %v", graphQLAddresses) - listeners = append(listeners, graphqlsec.NewWAFEventListener(waf.Handle, waf.actions, graphQLAddresses, cfg.WAFTimeout, l)) - } - - if len(grpcAddresses) > 0 { - log.Debug("appsec: creating the grpc waf event listener of the rules addresses %v", grpcAddresses) - listeners = append(listeners, grpcsec.NewWAFEventListener(waf.Handle, waf.actions, grpcAddresses, cfg.WAFTimeout, l)) - } - - if len(httpAddresses) > 0 { - log.Debug("appsec: creating http waf event listener of the rules addresses %v", httpAddresses) - listeners = append(listeners, httpsec.NewWAFEventListener(waf.Handle, waf.actions, httpAddresses, cfg.WAFTimeout, &cfg.APISec, l)) - } +// wafEventListeners is the global list of event listeners registered by contribs at init time. This +// is thread-safe assuming all writes (via AddWAFEventListener) are performed within `init` +// functions; so this is written to only during initialization, and is read from concurrently only +// during runtime when no writes are happening anymore. +var wafEventListeners []wafEventListener - return listeners, nil +// AddWAFEventListener adds a new WAF event listener to be registered whenever a new root operation +// is created. The normal way to use this is to call it from a `func init() {}` so that it is +// guaranteed to have happened before any listened to event may be emitted. +func AddWAFEventListener(fn wafEventListener) { + wafEventListeners = append(wafEventListeners, fn) }