-
Notifications
You must be signed in to change notification settings - Fork 2
/
ctxlogger.go
138 lines (120 loc) · 4.31 KB
/
ctxlogger.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package ctxlogger
import (
"context"
"encoding/json"
"github.com/Azure/aks-middleware/requestid"
log "log/slog"
loggable "buf.build/gen/go/service-hub/loggable/protocolbuffers/go/proto"
"google.golang.org/grpc"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
)
// ExtractFunction derives a request-scoped logger from the incoming context
// and/or request and returns it. UnaryServerInterceptor invokes it before the
// application's handler runs, so that the handler's logger (stored in the
// context with WithLogger) carries the extra ctx/request-specific attributes.
// defaultExtractFunction, for example, adds the method name and request ID.
type ExtractFunction func(
	ctx context.Context,
	req any,
	info *grpc.UnaryServerInfo,
	logger *log.Logger,
) *log.Logger
// loggerKeyType is an unexported type for this package's context keys, so
// they cannot collide with context keys defined by other packages.
type loggerKeyType int

// loggerKey is the context key under which WithLogger stores the logger.
const loggerKey = loggerKeyType(0)
// WithLogger returns a child of ctx that carries logger. The logger can be
// retrieved later with GetLogger.
func WithLogger(ctx context.Context, logger *log.Logger) context.Context {
	derived := context.WithValue(ctx, loggerKey, logger)
	return derived
}
// GetLogger returns the logger stored in ctx by WithLogger.
//
// If ctx is nil, or carries no logger (or a nil one), it returns the default
// slog logger tagged with src="self gen, not available in ctx". The tag makes
// such log lines easy to spot. The fallback is only constructed when it is
// actually needed, so the common path does not allocate.
func GetLogger(ctx context.Context) *log.Logger {
	if ctx != nil {
		// A stored nil *Logger would make callers panic, so treat it as absent.
		if l, ok := ctx.Value(loggerKey).(*log.Logger); ok && l != nil {
			return l
		}
	}
	return log.Default().With("src", "self gen, not available in ctx")
}
// UnaryServerInterceptor returns a grpc.UnaryServerInterceptor that stores a
// request-scoped logger in the context (see WithLogger / GetLogger) before
// calling the handler.
//
// logger is the base logger. Information that does not change per request
// should already be attached to it. If logger is nil, slog's default logger
// (as of each request) is used.
//
// extractFunction adds ctx- or request-specific information to the logger. It
// may be nil, in which case defaultExtractFunction is used. In every case the
// loggable fields of the request (see FilterLogs) are attached under the
// "request" key.
//
// Interceptors run in registration order, so the requestid interceptor must be
// registered before this one for the request ID to be available here.
func UnaryServerInterceptor(logger *log.Logger, extractFunction ExtractFunction) grpc.UnaryServerInterceptor {
	// Decide once, not on every request, which extractor to use.
	extract := extractFunction
	if extract == nil {
		extract = defaultExtractFunction
	}
	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		base := logger
		if base == nil {
			// Resolved per call so that a later slog.SetDefault is honored.
			base = log.Default()
		}
		l := extract(ctx, req, info, base)
		l = l.With(requestContentLogKey, FilterLogs(req))
		return handler(WithLogger(ctx, l), req)
	}
}
// methodLogKey is the attribute key for the gRPC full method name.
const methodLogKey = "method"

// requestContentLogKey is the attribute key for the loggable request fields
// produced by FilterLogs.
const requestContentLogKey = "request"
// defaultExtractFunction is the ExtractFunction that UnaryServerInterceptor
// uses when the caller supplies none. It tags logger with the gRPC full method
// name and with the request ID that the requestid package reports for ctx.
// req is not inspected.
func defaultExtractFunction(ctx context.Context, req any, info *grpc.UnaryServerInfo, logger *log.Logger) *log.Logger {
	withMethod := logger.With(methodLogKey, info.FullMethod)
	return withMethod.With(requestid.RequestIDLogKey, requestid.GetRequestID(ctx))
}
// filterLoggableFields removes from currentMap every entry whose corresponding
// field of message is not marked loggable. It recurses into nested message
// fields, both singular and repeated. currentMap is modified in place and
// also returned. A nil map or nil message is returned unchanged.
//
// currentMap is expected to be message rendered by protojson and decoded into
// a map. Its keys are therefore normally the fields' JSON names
// (lowerCamelCase), so keys are matched by JSON name first and by proto name
// second. Keys that match no field of message are left untouched.
//
// Map-typed fields are filtered as a whole (by their own loggable option), but
// their message values are not recursed into.
func filterLoggableFields(currentMap map[string]interface{}, message protoreflect.Message) map[string]interface{} {
	if currentMap == nil || message == nil {
		return currentMap
	}
	fields := message.Descriptor().Fields()
	for name, value := range currentMap {
		// protojson emits JSON names by default; a ByName-only lookup would miss
		// every multi-word field and leave it unfiltered.
		fd := fields.ByJSONName(name)
		if fd == nil {
			fd = fields.ByName(protoreflect.Name(name))
		}
		if fd == nil {
			continue
		}
		if !isLoggable(fd) {
			// Deleting the current key while ranging over a map is safe in Go.
			delete(currentMap, name)
			continue
		}
		// Only message-typed (non-map) fields have nested fields to filter.
		if fd.Message() == nil || fd.IsMap() {
			continue
		}
		switch v := value.(type) {
		case map[string]interface{}:
			if !fd.IsList() {
				currentMap[name] = filterLoggableFields(v, message.Get(fd).Message())
			}
		case []interface{}:
			if fd.IsList() {
				filterLoggableList(v, message.Get(fd).List())
			}
		}
	}
	return currentMap
}

// isLoggable reports whether fd has the loggable field option set to true.
// Options that cannot be read as a bool are treated as not loggable.
func isLoggable(fd protoreflect.FieldDescriptor) bool {
	opts, ok := fd.Options().(*descriptorpb.FieldOptions)
	if !ok {
		return false
	}
	enabled, ok := proto.GetExtension(opts, loggable.E_Loggable).(bool)
	return ok && enabled
}

// filterLoggableList filters each JSON object in items against the message
// at the same index of list (a repeated message field), in place.
func filterLoggableList(items []interface{}, list protoreflect.List) {
	n := min(len(items), list.Len())
	for i := 0; i < n; i++ {
		if obj, ok := items[i].(map[string]interface{}); ok {
			items[i] = filterLoggableFields(obj, list.Get(i).Message())
		}
	}
}
// FilterLogs converts req into a map suitable for a structured log attribute,
// keeping only the fields marked loggable (see filterLoggableFields). Keys
// are the fields' protojson (lowerCamelCase by default) names.
//
// It returns nil if req is not a proto.Message, or if req cannot be rendered
// as a JSON object. Conversion errors are reported through slog's default
// logger.
func FilterLogs(req any) map[string]interface{} {
	in, ok := req.(proto.Message)
	if !ok {
		return nil
	}
	jsonBytes, err := protojson.Marshal(in)
	if err != nil {
		// Stop here: unmarshalling the empty result would only log a second,
		// misleading error.
		log.Error(err.Error())
		return nil
	}
	var reqPayload map[string]interface{}
	if err := json.Unmarshal(jsonBytes, &reqPayload); err != nil {
		log.Error(err.Error())
		return nil
	}
	return filterLoggableFields(reqPayload, in.ProtoReflect())
}