/
conn.go
324 lines (280 loc) · 9.24 KB
/
conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
// Package grpcext wraps gRPC client connections so that requests made from a VU collect stats (metrics) and optional debug output.
package grpcext
import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"sort"
	"strconv"
	"strings"

	"github.com/ChipArtem/k6/lib"
	"github.com/ChipArtem/k6/metrics"
	protov1 "github.com/golang/protobuf/proto" //nolint:staticcheck,nolintlint // this is the old v1 version
	"github.com/sirupsen/logrus"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	grpcstats "google.golang.org/grpc/stats"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/encoding/prototext"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/types/descriptorpb"
	"google.golang.org/protobuf/types/dynamicpb"
)
// Request represents a gRPC request.
type Request struct {
	// MethodDescriptor describes the RPC being called; its Input and Output
	// message types are used by Invoke to build the dynamic request and
	// response messages. Invoke rejects a nil descriptor.
	MethodDescriptor protoreflect.MethodDescriptor
	// TagsAndMeta holds the tags and metadata attached to the metric samples
	// emitted for this request (the stats handler may add system tags to it).
	TagsAndMeta *metrics.TagsAndMeta
	// Message is the request body encoded as protobuf JSON; Invoke rejects
	// an empty message.
	Message []byte
}
// Response represents a gRPC response.
//
// Message and Error hold plain Go values (maps, slices, strings, numbers)
// obtained by round-tripping through JSON, so that script code can read
// their properties directly.
type Response struct {
	// Message is the decoded response message, marshaled with
	// EmitUnpopulated so zero-valued fields are present.
	Message interface{}
	// Error is nil when the RPC succeeds; when it fails it holds the gRPC
	// status (from status.Convert) decoded into a map.
	Error interface{}
	// Headers and Trailers hold the header and trailer metadata received for
	// the call.
	Headers  map[string][]string
	Trailers map[string][]string
	// Status is the gRPC status code of the call; it stays at its zero value
	// (codes.OK) when the RPC succeeds.
	Status codes.Code
}
// clientConnCloser is the subset of *grpc.ClientConn that Conn relies on:
// issuing RPCs (grpc.ClientConnInterface) plus Close. Dial stores a
// *grpc.ClientConn here; presumably the interface exists so tests can
// substitute a fake connection (confirm before relying on it).
type clientConnCloser interface {
	grpc.ClientConnInterface
	Close() error
}
// Conn is a gRPC client connection.
//
// It is created by Dial and wraps the raw connection that Invoke, Reflect
// and Close operate on.
type Conn struct {
	// raw is the underlying client connection.
	raw clientConnCloser
}
// DefaultOptions returns the dial options shared by every connection opened
// from a VU.
//
// They request a blocking dial (WithBlock), fail on non-temporary dial
// errors, return the underlying connection error, register the stats handler
// that records request metrics, and route the TCP connection through the
// Dialer of the VU state.
//
// getState is invoked lazily on every dial and on every stats event, so it
// should return the current VU state.
func DefaultOptions(getState func() *lib.State) []grpc.DialOption {
	return []grpc.DialOption{
		grpc.WithBlock(),
		grpc.FailOnNonTempDialError(true),
		grpc.WithReturnConnectionError(),
		grpc.WithStatsHandler(statsHandler{getState: getState}),
		grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
			// Always dial plain TCP through the VU's dialer.
			return getState().Dialer.DialContext(ctx, "tcp", addr)
		}),
	}
}
// Dial establishes a gRPC client connection to addr using the given dial
// options and wraps it in a Conn. Any error from grpc.DialContext is
// returned unchanged.
func Dial(ctx context.Context, addr string, options ...grpc.DialOption) (*Conn, error) {
	cc, dialErr := grpc.DialContext(ctx, addr, options...)
	if dialErr != nil {
		return nil, dialErr
	}
	wrapped := &Conn{raw: cc}
	return wrapped, nil
}
// Reflect uses the package's reflection client over this connection to fetch
// the FileDescriptorSet describing the service.
func (c *Conn) Reflect(ctx context.Context) (*descriptorpb.FileDescriptorSet, error) {
	client := reflectionClient{Conn: c.raw}
	return client.Reflect(ctx)
}
// Invoke executes a unary gRPC request.
//
// url is the method path passed to the underlying connection's Invoke, and md
// is sent as outgoing request metadata. req.Message must hold the request body
// encoded as protobuf JSON; it is decoded into a dynamic message of the
// method's input type. req.TagsAndMeta is attached to the context so the stats
// handler can tag the metrics emitted for this call.
//
// A failed RPC is NOT reported through the returned error. Instead, the
// returned Response carries the gRPC status code in Status and a map form of
// the status in Error. The error result is non-nil only in these cases:
//   - the arguments are invalid;
//   - the request JSON cannot be decoded;
//   - the response message cannot be converted to JSON.
func (c *Conn) Invoke(
	ctx context.Context,
	url string,
	md metadata.MD,
	req Request,
	opts ...grpc.CallOption,
) (*Response, error) {
	if url == "" {
		return nil, fmt.Errorf("url is required")
	}
	if req.MethodDescriptor == nil {
		return nil, fmt.Errorf("request method descriptor is required")
	}
	if len(req.Message) == 0 {
		return nil, fmt.Errorf("request message is required")
	}

	ctx = metadata.NewOutgoingContext(ctx, md)

	reqdm := dynamicpb.NewMessage(req.MethodDescriptor.Input())
	if err := protojson.Unmarshal(req.Message, reqdm); err != nil {
		return nil, fmt.Errorf("unable to serialise request object to protocol buffer: %w", err)
	}

	// The stats handler reads the request's tags and metadata back from ctx.
	ctx = withRPCState(ctx, &rpcState{tagsAndMeta: req.TagsAndMeta})

	resp := dynamicpb.NewMessage(req.MethodDescriptor.Output())
	header, trailer := metadata.New(nil), metadata.New(nil)

	copts := make([]grpc.CallOption, 0, len(opts)+2)
	copts = append(copts, opts...)
	copts = append(copts, grpc.Header(&header), grpc.Trailer(&trailer))

	err := c.raw.Invoke(ctx, url, reqdm, resp, copts...)

	response := Response{
		Headers:  header,
		Trailers: trailer,
	}

	// EmitUnpopulated keeps zero/default-valued fields in the JSON output.
	marshaler := protojson.MarshalOptions{EmitUnpopulated: true}

	if err != nil {
		sterr := status.Convert(err)
		response.Status = sterr.Code()

		// (rogchap) when you access a JSON property in goja, you are actually accessing the underling
		// Go type (struct, map, slice etc); because these are dynamic messages the Unmarshaled JSON does
		// not map back to a "real" field or value (as a normal Go type would). If we don't marshal and then
		// unmarshal back to a map, you will get "undefined" when accessing JSON properties, even when
		// JSON.Stringify() shows the object to be correctly present.
		errMsg := make(map[string]interface{})
		raw, convErr := marshaler.Marshal(sterr.Proto())
		if convErr == nil {
			convErr = json.Unmarshal(raw, &errMsg)
		}
		if convErr != nil {
			// Best effort: the status details could not be rendered (for
			// example an Any whose type the resolver does not know), so expose
			// at least the code and message rather than an empty object. The
			// code is a float64 to match what json.Unmarshal would produce.
			errMsg = map[string]interface{}{
				"code":    float64(sterr.Code()),
				"message": sterr.Message(),
			}
		}
		response.Error = errMsg
	}

	// (rogchap) there is a lot of marshaling/unmarshaling here, but if we just pass the dynamic message
	// the default Marshaller would be used, which would strip any zero/default values from the JSON.
	// eg. given this message:
	// message Point {
	//   double x = 1;
	//   double y = 2;
	//   double z = 3;
	// }
	// and a value like this:
	// msg := Point{X: 6, Y: 4, Z: 0}
	// would result in JSON output:
	// {"x":6,"y":4}
	// rather than the desired:
	// {"x":6,"y":4,"z":0}
	//
	// resp is never nil (dynamicpb.NewMessage always allocates), so the
	// message is converted even when the RPC failed; it then typically holds
	// only default values.
	raw, convErr := marshaler.Marshal(resp)
	if convErr != nil {
		return nil, fmt.Errorf("unable to convert response object to JSON: %w", convErr)
	}
	var msg interface{}
	if convErr = json.Unmarshal(raw, &msg); convErr != nil {
		return nil, fmt.Errorf("unable to convert response object to JSON: %w", convErr)
	}
	response.Message = msg

	return &response, nil
}
// Close closes the underlying gRPC client connection and returns any error
// reported while doing so.
func (c *Conn) Close() error {
	if err := c.raw.Close(); err != nil {
		return err
	}
	return nil
}
// statsHandler is the grpcstats.Handler registered by DefaultOptions. Only
// HandleRPC does real work; the connection hooks and TagRPC are no-ops.
type statsHandler struct {
	// getState returns the VU state; it is called on every RPC event.
	getState func() *lib.State
}

// TagConn implements the grpcstats.Handler interface. Connections are not
// tagged, so ctx is returned unchanged.
func (statsHandler) TagConn(ctx context.Context, _ *grpcstats.ConnTagInfo) context.Context {
	return ctx
}

// HandleConn implements the grpcstats.Handler interface. Connection-level
// events are ignored.
func (statsHandler) HandleConn(context.Context, grpcstats.ConnStats) {}

// TagRPC implements the grpcstats.Handler interface. Per-request state is
// attached by Invoke through withRPCState instead, so ctx is returned
// unchanged.
func (statsHandler) TagRPC(ctx context.Context, _ *grpcstats.RPCTagInfo) context.Context {
	return ctx
}
// HandleRPC implements the grpcstats.Handler interface.
//
// What it records depends on the stat type:
//   - *grpcstats.OutHeader: the remote IP, when the ip system tag is enabled.
//   - *grpcstats.End: the status code, when the status system tag is enabled,
//     plus a GRPCReqDuration sample timestamped at EndTime whose value is
//     EndTime-BeginTime.
//
// Regardless of the stat type, the event is also passed to DebugStat when
// the http-debug option is set.
func (h statsHandler) HandleRPC(ctx context.Context, stat grpcstats.RPCStats) {
	state := h.getState()

	rpc := getRPCState(ctx) //nolint:ifshort
	if rpc == nil {
		// Requests made by the reflection handler carry no rpcState, so fall
		// back to a snapshot of the VU state's current tags and metadata.
		// TODO: investigate this more, there has to be a way to fix it :/
		snapshot := state.Tags.GetCurrentValues()
		rpc = &rpcState{tagsAndMeta: &snapshot}
	}
	tagsAndMeta := rpc.tagsAndMeta

	switch s := stat.(type) {
	case *grpcstats.OutHeader:
		// TODO: figure out something better, e.g. via TagConn() or TagRPC()?
		wantIP := state.Options.SystemTags.Has(metrics.TagIP)
		if wantIP && s.RemoteAddr != nil {
			if ip, _, err := net.SplitHostPort(s.RemoteAddr.String()); err == nil {
				tagsAndMeta.SetSystemTagOrMeta(metrics.TagIP, ip)
			}
		}

	case *grpcstats.End:
		if state.Options.SystemTags.Has(metrics.TagStatus) {
			code := int(status.Code(s.Error))
			tagsAndMeta.SetSystemTagOrMeta(metrics.TagStatus, strconv.Itoa(code))
		}
		sample := metrics.Sample{
			TimeSeries: metrics.TimeSeries{
				Metric: state.BuiltinMetrics.GRPCReqDuration,
				Tags:   tagsAndMeta.Tags,
			},
			Time:     s.EndTime,
			Metadata: tagsAndMeta.Metadata,
			Value:    metrics.D(s.EndTime.Sub(s.BeginTime)),
		}
		metrics.PushIfNotDone(ctx, state.Samples, sample)
	}

	// (rogchap) Re-using --http-debug flag as gRPC is technically still HTTP
	if debugOpt := state.Options.HTTPDebug.String; debugOpt != "" {
		DebugStat(state.Logger.WithField("source", "http-debug"), stat, debugOpt)
	}
}
// DebugStat logs debugging information for a gRPC stats event.
//
// What gets logged depends on the event:
//   - Outgoing headers are always logged.
//   - Incoming headers and all trailers are logged only when non-empty.
//   - Payloads are logged only when httpDebugOption is "full".
//
// Any other stat type is ignored.
func DebugStat(logger logrus.FieldLogger, stat grpcstats.RPCStats, httpDebugOption string) {
	logPayloads := httpDebugOption == "full"

	switch s := stat.(type) {
	case *grpcstats.OutHeader:
		logger.Infof("Out Header:\nFull Method: %s\nRemote Address: %s\n%s\n",
			s.FullMethod, s.RemoteAddr, formatMetadata(s.Header))

	case *grpcstats.OutTrailer:
		if len(s.Trailer) == 0 {
			return
		}
		logger.Infof("Out Trailer:\n%s\n", formatMetadata(s.Trailer))

	case *grpcstats.OutPayload:
		if !logPayloads {
			return
		}
		logger.Infof("Out Payload:\nWire Length: %d\nSent Time: %s\n%s\n\n",
			s.WireLength, s.SentTime, formatPayload(s.Payload))

	case *grpcstats.InHeader:
		if len(s.Header) == 0 {
			return
		}
		logger.Infof("In Header:\nWire Length: %d\n%s\n", s.WireLength, formatMetadata(s.Header))

	case *grpcstats.InTrailer:
		if len(s.Trailer) == 0 {
			return
		}
		logger.Infof("In Trailer:\nWire Length: %d\n%s\n", s.WireLength, formatMetadata(s.Trailer))

	case *grpcstats.InPayload:
		if !logPayloads {
			return
		}
		logger.Infof("In Payload:\nWire Length: %d\nReceived Time: %s\n%s\n\n",
			s.WireLength, s.RecvTime, formatPayload(s.Payload))
	}
}
// formatMetadata renders gRPC metadata as one "key: v1, v2" line per key,
// each terminated by a newline. Keys are emitted in sorted order so the debug
// output is stable across runs (Go randomizes map iteration order).
//
// md is typed as the underlying map of metadata.MD, so metadata.MD values
// (and plain maps) can be passed directly. A nil or empty md yields "".
func formatMetadata(md map[string][]string) string {
	keys := make([]string, 0, len(md))
	for k := range md {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var sb strings.Builder
	for _, k := range keys {
		sb.WriteString(k)
		sb.WriteString(": ")
		sb.WriteString(strings.Join(md[k], ", "))
		sb.WriteByte('\n')
	}
	return sb.String()
}
// formatPayload renders a protobuf message as multi-line prototext for debug
// logging.
//
// Both APIv2 (google.golang.org/protobuf) messages and legacy APIv1
// (github.com/golang/protobuf) messages are accepted. Any other value, or a
// message that fails to marshal, yields an empty string.
func formatPayload(payload interface{}) string {
	var m proto.Message
	switch p := payload.(type) {
	case proto.Message:
		m = p
	case protov1.Message:
		// Legacy APIv1 message: adapt it to the APIv2 interface.
		m = protov1.MessageV2(p)
	default:
		return ""
	}

	opts := prototext.MarshalOptions{Multiline: true, Indent: " "}
	text, err := opts.Marshal(m)
	if err != nil {
		return ""
	}
	return string(text)
}
// contextKey is an unexported key type for values this package stores in a
// context.Context. Because it is a distinct type, its keys cannot collide
// with keys set by other packages, even ones using the same string.
type contextKey string

// ctxKeyRPCState is the context key under which the per-request rpcState is
// stored. It is a constant, so it cannot be reassigned at runtime and needs
// no gochecknoglobals exemption.
const ctxKeyRPCState = contextKey("rpcState")
// rpcState carries per-request data from Invoke to the stats handler through
// the request context.
type rpcState struct {
	// tagsAndMeta holds the tags and metadata for the request's metric
	// samples; HandleRPC may add the ip and status system tags to it.
	tagsAndMeta *metrics.TagsAndMeta
}

// withRPCState returns a copy of ctx that carries st, retrievable with
// getRPCState.
func withRPCState(ctx context.Context, st *rpcState) context.Context {
	return context.WithValue(ctx, ctxKeyRPCState, st)
}

// getRPCState returns the rpcState stored in ctx by withRPCState. It returns
// nil when ctx carries none, e.g. for requests issued by the reflection
// handler.
func getRPCState(ctx context.Context) *rpcState {
	// The comma-ok form yields nil both when the key is absent and when the
	// stored value has an unexpected type, instead of panicking.
	st, _ := ctx.Value(ctxKeyRPCState).(*rpcState)
	return st
}