-
Notifications
You must be signed in to change notification settings - Fork 2
/
middleware.go
205 lines (183 loc) · 6.21 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
package server
import (
"fmt"
"io"
"sync/atomic"
"time"
log "github.com/cihub/seelog"
"github.com/hailo-platform/H2O/protobuf/proto"
errors "github.com/hailo-platform/H2O/platform/errors"
"github.com/hailo-platform/H2O/platform/stats"
inst "github.com/hailo-platform/H2O/service/instrumentation"
trace "github.com/hailo-platform/H2O/service/trace"
traceproto "github.com/hailo-platform/H2O/platform/proto/trace"
)
// commonLoggerMiddleware returns a Middleware that writes one line per handled
// request to w using the Apache common log format
// http://httpd.apache.org/docs/2.2/logs.html#common
// If w is nil, nothing will be logged and the handler is returned unwrapped.
func commonLoggerMiddleware(w io.Writer) Middleware {
	return func(ep *Endpoint, h Handler) Handler {
		// If no writer is passed to middleware just return the handler
		if w == nil {
			return h
		}

		return func(req *Request) (proto.Message, errors.Error) {
			// Resolve the authenticated user (if any) once, up front.
			var userID string
			if auth := req.Auth(); auth != nil {
				if user := auth.AuthUser(); user != nil {
					userID = user.Id
				}
			}

			var err errors.Error
			var m proto.Message

			// In defer in case the handler panics. Note that on a panic m and
			// err still hold their zero values, so the line is written with
			// status 200 and size 0.
			defer func() {
				status := uint32(200)
				if err != nil {
					status = err.HttpCode()
				}

				size := 0
				if m != nil {
					// Render the response once and reuse it for both the
					// debug log and the reported size.
					body := m.String()
					log.Debug(body)
					size = len(body)
				}

				// Access logging is best effort: a failed write must not
				// affect the response, so the Fprintf error is ignored.
				fmt.Fprintf(w, "%s - %s [%s] \"%s %s %s\" %d %d\n",
					req.From(),
					userID,
					time.Now().Format("02/Jan/2006:15:04:05 -0700"),
					"GET", // Treat them all as GET's at the moment
					req.Endpoint(),
					"HTTP/1.0", // Has to be HTTP or apachetop ignores it
					status,
					size,
				)
			}()

			// Execute the actual handler
			m, err = h(req)
			return m, err
		}
	}
}
// tokenConstrainedMiddleware limits the max concurrent requests handled per caller.
//
// Each request must take a token from the caller's token channel before the
// wrapped handler runs; the token is put back once the handler returns. If no
// token becomes available within ep.Mean milliseconds, the request is rejected
// with an internal server error reporting that the server is out of capacity.
func tokenConstrainedMiddleware(ep *Endpoint, h Handler) Handler {
	return func(req *Request) (proto.Message, errors.Error) {
		callerName := req.From()
		if callerName == "" {
			callerName = "unknown"
		}
		tokenBucketName := fmt.Sprintf("server.tokens.%s", callerName)
		reqsBucketName := fmt.Sprintf("server.inflightrequests.%s", callerName)
		tokC := tokensChan(callerName)

		// An explicit timer (rather than time.After) lets us release it as
		// soon as a token is obtained instead of leaving it pending for the
		// whole timeout on every successful request.
		timeout := time.NewTimer(time.Duration(ep.Mean) * time.Millisecond)

		select {
		case t := <-tokC:
			timeout.Stop()
			defer func() {
				atomic.AddUint64(&inFlightRequests, ^uint64(0)) // This is actually a subtraction
				tokC <- t                                       // Return the token to the pool
			}()
			nowInFlight := atomic.AddUint64(&inFlightRequests, 1) // Update active request counters
			inst.Gauge(1.0, tokenBucketName, len(tokC))
			inst.Gauge(1.0, reqsBucketName, int(nowInFlight))
			return h(req)
		case <-timeout.C:
			// The timer has already fired here, so there is nothing to stop.
			inst.Gauge(1.0, tokenBucketName, len(tokC))
			inst.Counter(1.0, "server.error.capacity", 1)

			return nil, errors.InternalServerError("com.hailocab.kernel.server.capacity",
				fmt.Sprintf("Server %v out of capacity", Name))
		}
	}
}
// instrumentedMiddleware wraps the handler to provide instrumentation.
//
// After every call it records stats for the endpoint and emits a timing under
// "success.<endpoint>" or "error.<endpoint>". Errors additionally bump a
// "server.error.<code>" counter.
func instrumentedMiddleware(ep *Endpoint, h Handler) Handler {
	return func(req *Request) (rsp proto.Message, err errors.Error) {
		began := time.Now()

		// Deferred so that it still runs if the handler panics
		defer func() {
			stats.Record(ep, err, time.Since(began))

			// Successful calls land in the success bucket; so do errors
			// caused by clients (bad request / not found).
			// TODO: consider a new stat for clienterror?
			bucket := "success." + ep.Name
			if err != nil {
				inst.Counter(1.0, fmt.Sprintf("server.error.%s", err.Code()), 1)
				switch err.Type() {
				case errors.ErrorBadRequest, errors.ErrorNotFound:
					// Client-caused: keep the success bucket
				default:
					bucket = "error." + ep.Name
				}
			}
			inst.Timing(1.0, bucket, time.Since(began))
		}()

		// Named results are assigned before the deferred func runs
		return h(req)
	}
}
// tracingMiddleware adds tracing to a handler: an inbound trace event is sent
// before the handler runs and an outbound event, carrying the handler's
// response, error and elapsed time, once it has finished.
func tracingMiddleware(ep *Endpoint, h Handler) Handler {
	return func(req *Request) (rsp proto.Message, err errors.Error) {
		start := time.Now()
		traceIn(req)

		// The outbound trace must be wrapped in a closure: arguments to a
		// directly deferred call are evaluated when the defer statement
		// executes, which would hand traceOut a nil rsp, a nil err and a
		// near-zero duration instead of the handler's actual outcome.
		defer func() {
			traceOut(req, rsp, err, time.Since(start))
		}()

		rsp, err = h(req)
		return rsp, err
	}
}
// authMiddleware only calls the handler if the auth check passes.
// When the endpoint's Authoriser rejects the request, its error is returned
// and the handler is never invoked; otherwise the request's auth scope is
// marked as authorised before the handler runs.
func authMiddleware(ep *Endpoint, h Handler) Handler {
	return func(req *Request) (proto.Message, errors.Error) {
		authErr := ep.Authoriser.Authorise(req)
		if authErr != nil {
			return nil, authErr
		}

		req.Auth().SetAuthorised(true)
		return h(req)
	}
}
// waitGroupMiddleware registers each request with requestsWg for the duration
// of the handler call, so that anything waiting on requestsWg blocks until
// in-flight requests have finished.
func waitGroupMiddleware(ep *Endpoint, h Handler) Handler {
	tracked := func(req *Request) (proto.Message, errors.Error) {
		requestsWg.Add(1)
		// Deferred so the counter is released even if the handler panics
		defer requestsWg.Done()

		return h(req)
	}
	return tracked
}
// traceIn traces a request inbound to a service to handle.
// Nothing happens unless the request is flagged for tracing; otherwise an IN
// event is built and sent from a separate goroutine.
func traceIn(req *Request) {
	if !req.shouldTrace() {
		return
	}

	inbound := &traceproto.Event{
		Timestamp:         proto.Int64(time.Now().UnixNano()),
		TraceId:           proto.String(req.TraceID()),
		Type:              traceproto.Event_IN.Enum(),
		MessageId:         proto.String(req.MessageID()),
		ParentMessageId:   proto.String(req.ParentMessageID()),
		From:              proto.String(req.From()),
		To:                proto.String(fmt.Sprintf("%v.%v", req.Service(), req.Endpoint())),
		Hostname:          proto.String(hostname),
		Az:                proto.String(az),
		Payload:           proto.String(""), // @todo
		HandlerInstanceId: proto.String(InstanceID),
		PersistentTrace:   proto.Bool(req.TraceShouldPersist()),
	}

	// Send off the request path so tracing does not delay the handler
	go trace.Send(inbound)
}
// traceOut traces a request outbound from a service handler.
// Nothing happens unless the request is flagged for tracing. The OUT event
// carries the handler duration d and, when err is non-nil, its code and
// description. msg is currently unused (the payload is left empty).
func traceOut(req *Request, msg proto.Message, err errors.Error, d time.Duration) {
	if !req.shouldTrace() {
		return
	}

	outbound := &traceproto.Event{
		Timestamp:         proto.Int64(time.Now().UnixNano()),
		TraceId:           proto.String(req.TraceID()),
		Type:              traceproto.Event_OUT.Enum(),
		MessageId:         proto.String(req.MessageID()),
		ParentMessageId:   proto.String(req.ParentMessageID()),
		From:              proto.String(req.From()),
		To:                proto.String(fmt.Sprintf("%v.%v", req.Service(), req.Endpoint())),
		Hostname:          proto.String(hostname),
		Az:                proto.String(az),
		Payload:           proto.String(""), // @todo
		HandlerInstanceId: proto.String(InstanceID),
		Duration:          proto.Int64(int64(d)),
		PersistentTrace:   proto.Bool(req.TraceShouldPersist()),
	}

	// Attach failure details only when the handler returned an error
	if err != nil {
		outbound.ErrorCode = proto.String(err.Code())
		outbound.ErrorDescription = proto.String(err.Description())
	}

	// Send from a separate goroutine so the caller is not blocked
	go trace.Send(outbound)
}