/
mock.go
371 lines (324 loc) · 9.52 KB
/
mock.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
//go:build go1.18
// +build go1.18
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package mock_server
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"sync"
"time"
)
const (
// mockReadError is the name of the marker header used to request a failing
// response body: mockResponse.write adds it to the response when a body read
// error was requested, and Server.Do removes it again and replaces the
// response body with a readFailer.
mockReadError = "mock-read-error"
)
// Server is a wrapper around an httptest.Server.
// The serving of requests is not safe for concurrent use,
// which is OK for right now as each test creates its own
// server and doesn't create additional goroutines.
type Server struct {
// srv is the underlying httptest server that receives the requests sent by Do.
srv *httptest.Server
// static is the static response; if this is not nil it's always returned.
static *mockResponse
// respLock synchronizes access to resp.
respLock *sync.RWMutex
// resp is the queue of responses. Each response is taken from the front.
resp []mockResponse
// count tracks the number of requests that have been made; it is incremented by Do.
count int
// routeAllRequestsToMockServer, when true, makes Do rewrite the scheme and host
// of each request URL so that the request is sent to the httptest server.
routeAllRequestsToMockServer bool
}
// newServer allocates a Server whose response-queue lock is ready for use.
// The underlying httptest server is not created here.
func newServer() *Server {
	s := new(Server)
	s.respLock = new(sync.RWMutex)
	return s
}
// NewServer creates a Server backed by a plain-HTTP httptest server.
// Every ServerOption in cfg is applied, in order, before the server is started.
// The returned func closes the server and must be called once the server is
// no longer needed.
func NewServer(cfg ...ServerOption) (*Server, func()) {
	server := newServer()
	handler := http.HandlerFunc(server.serveHTTP)
	server.srv = httptest.NewUnstartedServer(handler)
	for i := range cfg {
		cfg[i].apply(server)
	}
	server.srv.Start()
	shutdown := func() {
		server.srv.Close()
	}
	return server, shutdown
}
// NewTLSServer creates a Server backed by a TLS-enabled httptest server.
// Every ServerOption in cfg is applied, in order, before the server is started
// with StartTLS.
// The returned func closes the server and must be called once the server is
// no longer needed.
func NewTLSServer(cfg ...ServerOption) (*Server, func()) {
	server := newServer()
	handler := http.HandlerFunc(server.serveHTTP)
	server.srv = httptest.NewUnstartedServer(handler)
	for i := range cfg {
		cfg[i].apply(server)
	}
	server.srv.StartTLS()
	shutdown := func() {
		server.srv.Close()
	}
	return server, shutdown
}
// ServerConfig exposes the http.Server configuration of the underlying test
// server. Treat it as read-only: Start or StartTLS has already been called on
// the server by the time a Server is handed out.
func (s *Server) ServerConfig() *http.Server {
	underlying := s.srv
	return underlying.Config
}
// isErrorResp reports whether the response that would be served next carries
// an error. The static response, when set, takes precedence over the queue.
// It panics with "no more responses" when there is no static response and the
// queue is empty.
func (s *Server) isErrorResp() bool {
	if static := s.static; static != nil {
		return static.err != nil
	}
	s.respLock.RLock()
	defer s.respLock.RUnlock()
	if len(s.resp) == 0 {
		panic("no more responses")
	}
	return s.resp[0].err != nil
}
// getResponse returns the static response when one is configured; otherwise
// it removes the response at the front of the queue and returns it.
// It panics with "no more responses" when there is no static response and the
// queue is empty.
func (s *Server) getResponse() mockResponse {
	// always favor static response
	if s.static != nil {
		return *s.static
	}
	// Take the write lock before looking at the queue so that the emptiness
	// check and the pop happen as one atomic step. Checking the length outside
	// the lock would let two callers both see one remaining element and then
	// have the second one index an empty slice.
	s.respLock.Lock()
	defer s.respLock.Unlock()
	if len(s.resp) == 0 {
		// the deferred Unlock still runs while the panic unwinds
		panic("no more responses")
	}
	// pop off first response and return it
	resp := s.resp[0]
	s.resp = s.resp[1:]
	return resp
}
// URL reports the base URL of the underlying test server as a string.
func (s *Server) URL() string {
	endpoint := s.srv.URL
	return endpoint
}
// Do implements the azcore.Transport interface on Server.
//
// Each call increments the request counter. If the next response carries an
// error, that response is consumed and its error is returned without sending
// anything to the test server. Otherwise the request is sent through the test
// server's client; when the server was configured to route all requests to the
// mock server, the scheme and host of req.URL are temporarily replaced with
// those of the test server and restored before Do returns.
//
// If the received response carries the mock-read-error header, the header is
// removed and the body is replaced with one whose Read always fails.
//
// Calling this when the response queue is empty and no static
// response has been set will cause a panic.
func (s *Server) Do(req *http.Request) (*http.Response, error) {
	s.count++
	// error responses are returned here
	if s.isErrorResp() {
		return nil, s.getResponse().err
	}
	if s.routeAllRequestsToMockServer {
		originalURL := req.URL
		// shallow copy so the caller's URL value itself is never modified
		mockURL := *req.URL
		srvURL, err := url.Parse(s.srv.URL)
		if err != nil {
			// %w keeps the parse error in the chain for errors.Is/errors.As;
			// the message text is the same as before.
			return nil, fmt.Errorf("Unable to parse the test server URL: %w", err)
		}
		mockURL.Host = srvURL.Host
		mockURL.Scheme = srvURL.Scheme
		req.URL = &mockURL
		// hand the request back to the caller with its original URL
		defer func() { req.URL = originalURL }()
	}
	resp, err := s.srv.Client().Do(req)
	if err != nil {
		return resp, err
	}
	// wrap the response body in a readFailer if the mock-read-error header is set
	if resp.Header.Get(mockReadError) != "" {
		resp.Body = &readFailer{wrapped: resp.Body}
		resp.Header.Del(mockReadError)
	}
	return resp, nil
}
// serveHTTP is the handler installed on the underlying test server. It picks
// the response to send for req, optionally sleeps for the response's delay,
// and writes the response to w. It panics if writing the response fails.
//
// Responses with a predicate are consulted in queue order: a response whose
// predicate rejects req is dropped and the next one is tried; when a predicate
// accepts req, that response is used and the response queued right after it
// is removed as well.
func (s *Server) serveHTTP(w http.ResponseWriter, req *http.Request) {
	selected := s.getResponse()
	for selected.pred != nil {
		if selected.pred(req) {
			// response applies to this request, so remove the next response
			s.getResponse()
			break
		}
		// predicate rejected the request; move on to the next queued response
		selected = s.getResponse()
	}
	if selected.delay > 0 {
		time.Sleep(selected.delay)
	}
	if err := selected.write(w); err != nil {
		panic(err)
	}
}
// Requests reports how many requests have been made, i.e. how many times Do
// has been called on this Server.
func (s *Server) Requests() int {
	made := s.count
	return made
}
// AppendError appends the error to the end of the response queue.
// When this entry reaches the front of the queue, Do returns err instead of
// sending the request to the test server.
func (s *Server) AppendError(err error) {
	// The queue is read and popped under respLock, so writers must hold it too.
	s.respLock.Lock()
	defer s.respLock.Unlock()
	s.resp = append(s.resp, mockResponse{err: err})
}
// RepeatError queues err n times at the end of the response queue.
// Nothing is queued when n is zero or negative.
func (s *Server) RepeatError(n int, err error) {
	for remaining := n; remaining > 0; remaining-- {
		s.AppendError(err)
	}
}
// SetError makes every request fail with err by installing it as the static
// response. Any responses set via other methods will be ignored.
func (s *Server) SetError(err error) {
	fixed := mockResponse{err: err}
	s.static = &fixed
}
// AppendResponse appends the response to the end of the response queue.
// The response starts as an http.StatusOK with no headers and no body; each
// option in opts is then applied to it in order.
// If no options are provided the default response is an http.StatusOK.
func (s *Server) AppendResponse(opts ...ResponseOption) {
	mr := mockResponse{code: http.StatusOK, headers: http.Header{}}
	for _, o := range opts {
		o.apply(&mr)
	}
	// Options run before the lock is taken; only the queue mutation itself
	// needs respLock, which the readers of the queue also hold.
	s.respLock.Lock()
	defer s.respLock.Unlock()
	s.resp = append(s.resp, mr)
}
// RepeatResponse queues n responses, each built from opts, at the end of the
// response queue. Nothing is queued when n is zero or negative.
// If no options are provided the default response is an http.StatusOK.
func (s *Server) RepeatResponse(n int, opts ...ResponseOption) {
	for remaining := n; remaining > 0; remaining-- {
		s.AppendResponse(opts...)
	}
}
// SetResponse installs a static response that is returned for every request.
// Any responses set via other methods will be ignored.
// The response starts as an http.StatusOK with no headers and no body; each
// option in opts is then applied to it in order.
// NOTE: does not support WithPredicate(), will cause a panic.
func (s *Server) SetResponse(opts ...ResponseOption) {
	static := &mockResponse{code: http.StatusOK, headers: http.Header{}}
	for i := range opts {
		opts[i].apply(static)
	}
	if static.pred != nil {
		panic("WithPredicate not supported for static responses")
	}
	s.static = static
}
// ServerOption is an abstraction for configuring a mock Server.
// NewServer and NewTLSServer apply the options they are given before the
// underlying httptest server is started.
type ServerOption interface {
// apply configures s according to the option.
apply(s *Server)
}
// fnSrvOpt adapts a plain function to the ServerOption interface.
type fnSrvOpt func(*Server)

// apply runs the wrapped function against s.
func (f fnSrvOpt) apply(s *Server) { f(s) }
// WithTransformAllRequestsToTestServerUrl returns a ServerOption that makes
// Do send every request to the test server by rewriting the scheme and host
// of the request URL.
func WithTransformAllRequestsToTestServerUrl() ServerOption {
	var opt fnSrvOpt = func(s *Server) {
		s.routeAllRequestsToMockServer = true
	}
	return opt
}
// WithTLSConfig returns a ServerOption that assigns cfg to the TLS field of
// the underlying httptest server.
func WithTLSConfig(cfg *tls.Config) ServerOption {
	var opt fnSrvOpt = func(s *Server) {
		s.srv.TLS = cfg
	}
	return opt
}
// WithHTTP2Enabled returns a ServerOption that sets the EnableHTTP2 field of
// the underlying httptest server to the given value.
func WithHTTP2Enabled(enabled bool) ServerOption {
	var opt fnSrvOpt = func(s *Server) {
		s.srv.EnableHTTP2 = enabled
	}
	return opt
}
// ResponseOption is an abstraction for configuring a mock HTTP response.
// AppendResponse and SetResponse apply the options they are given, in order,
// to a response that starts out as an http.StatusOK.
type ResponseOption interface {
// apply configures mr according to the option.
apply(mr *mockResponse)
}
// fnRespOpt adapts a plain function to the ResponseOption interface.
type fnRespOpt func(*mockResponse)

// apply runs the wrapped function against mr.
func (f fnRespOpt) apply(mr *mockResponse) { f(mr) }
// ResponsePredicate reports whether a queued response applies to the given
// HTTP request. It is attached to a response with WithPredicate and is
// invoked by the server's handler when that response is taken from the queue.
type ResponsePredicate func(*http.Request) bool
// mockResponse describes one canned outcome: either an HTTP response written
// by the test server or an error returned directly from Server.Do.
type mockResponse struct {
// code is the HTTP status code passed to WriteHeader.
code int
// body is the response payload; when nil, nothing is written after the header.
body []byte
// headers are added to the response before the status code is written.
headers http.Header
// err, when non-nil, is returned by Server.Do instead of sending the request.
err error
// rerr, when true, makes write add the mock-read-error header to the response.
rerr bool
// delay, when greater than zero, is how long the handler sleeps before writing.
delay time.Duration
// pred is an optional predicate checked by the handler; nil means the
// response is used unconditionally.
pred ResponsePredicate
}
// write sends the response to w: first the configured headers, then the
// mock-read-error marker header when a body read failure was requested, then
// the status code, and finally the body when it is non-nil.
// It returns the error from writing the body, if any.
func (mr mockResponse) write(w http.ResponseWriter) error {
	for name, values := range mr.headers {
		for i := range values {
			w.Header().Add(name, values[i])
		}
	}
	if mr.rerr {
		w.Header().Add(mockReadError, "true")
	}
	w.WriteHeader(mr.code)
	if mr.body == nil {
		return nil
	}
	_, err := w.Write(mr.body)
	return err
}
// WithStatusCode returns a ResponseOption that uses c as the HTTP status code
// of the response.
func WithStatusCode(c int) ResponseOption {
	var opt fnRespOpt = func(mr *mockResponse) {
		mr.code = c
	}
	return opt
}
// WithBody returns a ResponseOption that uses b as the body of the response.
// The slice is stored as given; it is not copied.
func WithBody(b []byte) ResponseOption {
	var opt fnRespOpt = func(mr *mockResponse) {
		mr.body = b
	}
	return opt
}
// WithHeader returns a ResponseOption that adds header k with value v to the
// response. Values are added, not replaced, so using the same key more than
// once yields multiple values for that header.
func WithHeader(k, v string) ResponseOption {
	var opt fnRespOpt = func(mr *mockResponse) {
		mr.headers.Add(k, v)
	}
	return opt
}
// WithSlowResponse returns a ResponseOption that makes the server sleep for d
// before it writes the response.
func WithSlowResponse(d time.Duration) ResponseOption {
	var opt fnRespOpt = func(mr *mockResponse) {
		mr.delay = d
	}
	return opt
}
// WithBodyReadError returns a ResponseOption that marks the response so that
// reading its body fails.
func WithBodyReadError() ResponseOption {
	var opt fnRespOpt = func(mr *mockResponse) {
		mr.rerr = true
	}
	return opt
}
// WithPredicate returns a ResponseOption that attaches predicate p to the
// response; p is invoked with the incoming HTTP request.
// If the predicate returns true, the response associated with the predicate is
// returned and the response queued right after it is removed from the queue.
// When false, the associated response is discarded and the next one in the
// queue is considered instead.
// NOTE: not supported for static responses, will cause a panic.
func WithPredicate(p ResponsePredicate) ResponseOption {
	var opt fnRespOpt = func(mr *mockResponse) {
		mr.pred = p
	}
	return opt
}
// readFailer is a response body whose Read always fails. Close is forwarded
// to the body it wraps.
type readFailer struct {
	wrapped io.ReadCloser
}

// Close closes the wrapped body and returns its result.
func (rf *readFailer) Close() error {
	inner := rf.wrapped
	return inner.Close()
}

// Read never fills the buffer: it always reports zero bytes read together
// with a "mock read failure" error.
func (rf *readFailer) Read(_ []byte) (int, error) {
	failure := errors.New("mock read failure")
	return 0, failure
}