/
interceptors.go
292 lines (260 loc) · 7.9 KB
/
interceptors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
package aogrpc
import (
"fmt"
"io"
fp "path/filepath"
"strings"
"sync"
"time"
"golang.org/x/net/context"
"github.com/appoptics/appoptics-apm-go/v1/ao"
"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// actionFromMethod returns the last "/"-separated segment of a gRPC method name,
// e.g. "Method" for "/pkg.Service/Method". A name without any "/" is returned
// unchanged, and a name ending in "/" yields "".
func actionFromMethod(method string) string {
	lastSlash := strings.LastIndex(method, "/")
	// When there is no slash lastSlash is -1, so the slice below is the whole string.
	return method[lastSlash+1:]
}
// StackTracer is a copy of the stackTracer interface of pkg/errors.
//
// pkg/errors does not export its stackTracer interface, so this type mirrors its
// method set; getErrClass type-asserts errors against it to reach their stack trace.
// This may be fragile as stackTracer is not imported, just try our best though.
type StackTracer interface {
	// StackTrace returns the error's pkg/errors stack trace; getTopFramePkg
	// inspects only its first frame.
	StackTrace() errors.StackTrace
}
// getErrClass picks the error class reported to AppOptics for err.
//
// If err implements StackTracer, the class is the package name that
// getTopFramePkg derives from the top stack frame. In every other case,
// including a failure of getTopFramePkg, the generic class "error" is used.
func getErrClass(err error) string {
	const fallbackClass = "error"

	tracer, hasStack := err.(StackTracer)
	if !hasStack {
		return fallbackClass
	}
	pkgName, lookupErr := getTopFramePkg(tracer)
	if lookupErr != nil {
		// Nothing better is available; report the generic class.
		return fallbackClass
	}
	return pkgName
}
// Sentinel errors returned by getTopFramePkg. getErrClass does not distinguish
// between them: any of them makes it fall back to the generic "error" class.
var (
	// errNilStackTracer is returned when getTopFramePkg is given a nil StackTracer.
	errNilStackTracer = errors.New("nil stackTracer pointer")
	// errEmptyStackTrace is returned when the stack trace contains no frames.
	errEmptyStackTrace = errors.New("empty stack trace")
	// errGetTopFramePkg is returned when no package name can be derived from
	// the formatted top frame.
	errGetTopFramePkg = errors.New("failed to get top frame package name")
)
// getTopFramePkg returns the name of the directory that contains the source file
// of the top (first) frame of st's stack trace. By Go convention this is usually
// the package name.
//
// Errors:
//   - errNilStackTracer if st is nil;
//   - errEmptyStackTrace if the trace has no frames;
//   - errGetTopFramePkg if the formatted frame does not have the expected shape,
//     or if its file path has no usable directory component.
func getTopFramePkg(st StackTracer) (string, error) {
	if st == nil {
		return "", errNilStackTracer
	}
	trace := st.StackTrace()
	if len(trace) == 0 {
		return "", errEmptyStackTrace
	}

	// With the %+s verb pkg/errors formats a frame as "<function>\n\t<file path>".
	// It is fragile to use this hard-coded separator, see:
	// https://github.com/pkg/errors/blob/30136e27e2ac8d167177e8a583aa4c3fea5be833/stack.go#L63
	fs := fmt.Sprintf("%+s", trace[0])
	frames := strings.Split(fs, "\n\t")
	if len(frames) != 2 {
		return "", errGetTopFramePkg
	}

	pkg := fp.Base(fp.Dir(frames[1]))
	// A file path without a directory part yields "." here. One example is the bare
	// "unknown" placeholder that some pkg/errors versions emit for unresolvable
	// frames (confirm against the vendored version). A root-only path yields the
	// separator. Neither is a package name, so the caller's fallback is used instead.
	if pkg == "." || pkg == string(fp.Separator) {
		return "", errGetTopFramePkg
	}
	return pkg, nil
}
// getFirstValFromMd returns the first value stored in md under key. If key has no
// value, it tries the lower-cased form of key (gRPC metadata keys are
// conventionally lower case). It returns "" when neither form has a value.
//
// A key that is present with an empty value slice is treated as absent. md is a
// plain map of slices, so this can happen, and the unchecked [0] used previously
// would panic.
func getFirstValFromMd(md metadata.MD, key string) string {
	if vals := md[key]; len(vals) > 0 {
		return vals[0]
	}
	if vals := md[strings.ToLower(key)]; len(vals) > 0 {
		return vals[0]
	}
	return ""
}
// tracingContext starts an AppOptics trace for an incoming gRPC call. It returns a
// context carrying that trace, together with the trace itself.
//
// The upstream X-Trace ID, X-Trace-Options and X-Trace-Options signature are read
// from the incoming gRPC metadata when present. methodName is the gRPC method name
// (callers pass info.FullMethod). Its last "/" segment becomes the action, and the
// transaction name is "<serverName>.<action>".
//
// statusCode is dereferenced each time ao invokes the entry-event callback (CB), so
// it reports the value current at that moment. If statusCode is nil, no "Status"
// KV is added by the callback.
func tracingContext(ctx context.Context, serverName string, methodName string, statusCode *int) (context.Context, ao.Trace) {
	action := actionFromMethod(methodName)

	var xtID, opt, signature string
	if md, ok := metadata.FromIncomingContext(ctx); ok {
		xtID = getFirstValFromMd(md, ao.HTTPHeaderName)
		opt = getFirstValFromMd(md, ao.HTTPHeaderXTraceOptions)
		signature = getFirstValFromMd(md, ao.HTTPHeaderXTraceOptionsSignature)
	}

	t := ao.NewTraceWithOptions(serverName, ao.SpanOptions{
		ContextOptions: ao.ContextOptions{
			MdStr:                  xtID,
			URL:                    methodName,
			XTraceOptions:          opt,
			XTraceOptionsSignature: signature,
			CB: func() ao.KVMap {
				kvs := ao.KVMap{
					"Method":     "POST",
					"Controller": serverName,
					"Action":     action,
					"URL":        methodName,
				}
				// Report the status as a plain int. Storing the *int itself would hand
				// the reporter a pointer rather than a value.
				if statusCode != nil {
					kvs["Status"] = *statusCode
				}
				return kvs
			},
		},
	})
	t.SetMethod("POST")
	t.SetTransactionName(serverName + "." + action)
	t.SetStartTime(time.Now())
	return ao.NewContext(ctx, t), t
}
// UnaryServerInterceptor returns an interceptor that traces gRPC unary server RPCs using AppOptics.
// If the client is using UnaryClientInterceptor, the distributed trace's context will be read from the client.
//
// The handler runs with a context that carries the new trace. A handler error is recorded on the
// trace and turns the reported status into 500; otherwise the status is 200. The trace is ended
// when the interceptor returns.
func UnaryServerInterceptor(serverName string) grpc.UnaryServerInterceptor {
	return func(
		ctx context.Context,
		req interface{},
		info *grpc.UnaryServerInfo,
		handler grpc.UnaryHandler,
	) (interface{}, error) {
		code := 200
		traceCtx, trace := tracingContext(ctx, serverName, info.FullMethod, &code)
		defer func() {
			trace.SetStatus(code)
			ao.EndTrace(traceCtx)
		}()

		reply, handlerErr := handler(traceCtx, req)
		if handlerErr != nil {
			code = 500
			ao.Error(traceCtx, getErrClass(handlerErr), handlerErr.Error())
		}
		return reply, handlerErr
	}
}
// wrappedServerStream is a grpc.ServerStream whose Context can be replaced.
// It is adapted from the grpc_middleware project.
type wrappedServerStream struct {
	grpc.ServerStream
	// WrappedContext is the context returned by Context instead of the
	// embedded stream's own context.
	WrappedContext context.Context
}

// Context returns the overriding context stored in WrappedContext.
func (ws *wrappedServerStream) Context() context.Context {
	return ws.WrappedContext
}

// wrapServerStream returns stream itself if it is already a *wrappedServerStream.
// Otherwise it returns a new wrapper whose WrappedContext starts out as the
// stream's current context.
func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream {
	ws, alreadyWrapped := stream.(*wrappedServerStream)
	if !alreadyWrapped {
		ws = &wrappedServerStream{ServerStream: stream, WrappedContext: stream.Context()}
	}
	return ws
}
// StreamServerInterceptor returns an interceptor that traces gRPC streaming server RPCs using AppOptics.
//
// The trace is started before the handler is invoked, and the handler sees it through the stream's
// Context. The trace is ended when the handler returns, i.e. after the handler is done exchanging
// messages.
//   - A handler returning io.EOF is treated as success: nil is returned to gRPC and the status stays 200.
//   - Any other error is recorded on the trace, sets the status to 500 and is returned unchanged.
func StreamServerInterceptor(serverName string) grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		code := 200
		traceCtx, trace := tracingContext(stream.Context(), serverName, info.FullMethod, &code)
		defer func() {
			trace.SetStatus(code)
			ao.EndTrace(traceCtx)
		}()

		ws := wrapServerStream(stream)
		ws.WrappedContext = traceCtx

		switch handlerErr := handler(srv, ws); {
		case handlerErr == io.EOF:
			return nil
		case handlerErr != nil:
			code = 500
			ao.Error(traceCtx, getErrClass(handlerErr), handlerErr.Error())
			return handlerErr
		default:
			return nil
		}
	}
}
// UnaryClientInterceptor returns an interceptor that traces a unary RPC from a gRPC client to a server using
// AppOptics, by propagating the distributed trace's context from client to server using gRPC metadata.
//
// An RPC span is opened around the call. When the span has a non-empty X-Trace ID, the ID is appended
// to the outgoing metadata. An error from the invoker is recorded on the span and returned unchanged.
func UnaryClientInterceptor(target string, serviceName string) grpc.UnaryClientInterceptor {
	return func(
		ctx context.Context,
		method string,
		req, resp interface{},
		cc *grpc.ClientConn,
		invoker grpc.UnaryInvoker,
		opts ...grpc.CallOption,
	) error {
		span := ao.BeginRPCSpan(ctx, actionFromMethod(method), "grpc", serviceName, target)
		defer span.End()

		if traceID := span.MetadataString(); traceID != "" {
			ctx = metadata.AppendToOutgoingContext(ctx, ao.HTTPHeaderName, traceID)
		}

		callErr := invoker(ctx, method, req, resp, cc, opts...)
		if callErr != nil {
			span.Error(getErrClass(callErr), callErr.Error())
		}
		return callErr
	}
}
// StreamClientInterceptor returns an interceptor that traces a streaming RPC from a gRPC client to a server using
// AppOptics, by propagating the distributed trace's context from client to server using gRPC metadata.
//
// The client span starts when the stream is created. It ends at the first of these events:
//   - any stream method (Header, SendMsg, CloseSend, RecvMsg) returns an error, including io.EOF;
//   - for RPCs whose server does not stream (desc.ServerStreams is false), the single response
//     has been received;
//   - ctx is done.
//
// Errors other than io.EOF (including ctx.Err() on cancellation) are recorded on the span.
func StreamClientInterceptor(target string, serviceName string) grpc.StreamClientInterceptor {
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		action := actionFromMethod(method)
		span := ao.BeginRPCSpan(ctx, action, "grpc", serviceName, target)
		if xtID := span.MetadataString(); len(xtID) > 0 {
			ctx = metadata.AppendToOutgoingContext(ctx, ao.HTTPHeaderName, xtID)
		}

		clientStream, err := streamer(ctx, desc, cc, method, opts...)
		if err != nil {
			closeSpan(span, err)
			return nil, err
		}

		s := &tracedClientStream{
			ClientStream: clientStream,
			span:         span,
			// For non-server-streaming RPCs, RecvMsg returns nil after the only response
			// and the caller normally never sees io.EOF, so the span must be closed then.
			singleResponse: desc != nil && !desc.ServerStreams,
			done:           make(chan struct{}),
		}

		// A caller may abandon the stream by cancelling ctx instead of draining it.
		// End the span in that case so it is not leaked. The goroutine exits as soon
		// as the span is closed by any path. A context that can never be done gets
		// no goroutine.
		if ctxDone := ctx.Done(); ctxDone != nil {
			go func() {
				select {
				case <-ctxDone:
					s.closeSpan(ctx.Err())
				case <-s.done:
				}
			}()
		}
		return s, nil
	}
}

// tracedClientStream wraps a grpc.ClientStream and ends its AppOptics span exactly
// once, when the stream finishes. All fields besides the embedded stream and span
// are optional: their zero values keep the original "close on first error" behaviour.
type tracedClientStream struct {
	grpc.ClientStream
	mu     sync.Mutex // guards closed and the close of done
	closed bool
	span   ao.Span
	// singleResponse makes a successful RecvMsg close the span (non-server-streaming RPCs).
	singleResponse bool
	// done, when non-nil, is closed once the span has been closed.
	done chan struct{}
}

// Header forwards to the wrapped stream, closing the span if it fails.
func (s *tracedClientStream) Header() (metadata.MD, error) {
	h, err := s.ClientStream.Header()
	if err != nil {
		s.closeSpan(err)
	}
	return h, err
}

// SendMsg forwards to the wrapped stream, closing the span if it fails.
func (s *tracedClientStream) SendMsg(m interface{}) error {
	err := s.ClientStream.SendMsg(m)
	if err != nil {
		s.closeSpan(err)
	}
	return err
}

// CloseSend forwards to the wrapped stream, closing the span if it fails.
func (s *tracedClientStream) CloseSend() error {
	err := s.ClientStream.CloseSend()
	if err != nil {
		s.closeSpan(err)
	}
	return err
}

// RecvMsg forwards to the wrapped stream. The span is closed on any error
// (io.EOF is a normal end and is not recorded as a failure). For single-response
// RPCs the span is also closed after a successful receive.
func (s *tracedClientStream) RecvMsg(m interface{}) error {
	err := s.ClientStream.RecvMsg(m)
	if err != nil || s.singleResponse {
		s.closeSpan(err)
	}
	return err
}

// closeSpan ends the span on its first call and is a no-op afterwards. It is safe
// to call from several goroutines.
func (s *tracedClientStream) closeSpan(err error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return
	}
	closeSpan(s.span, err)
	s.closed = true
	if s.done != nil {
		close(s.done)
	}
}
// closeSpan ends span. It first records err on the span, unless err is nil or
// io.EOF: io.EOF marks a normal end of stream, not a failure.
func closeSpan(span ao.Span, err error) {
	switch err {
	case nil, io.EOF:
		// Nothing to record.
	default:
		span.Error(getErrClass(err), err.Error())
	}
	span.End()
}