/
extract.go
161 lines (139 loc) · 3.93 KB
/
extract.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
// Package domain defines all the possible extraction methods that can be applied on a response.
//
// Signatures
//
// It is important for consistency that each extractor follows the same function signature:
// func(r *RequestTest, resp *http.Response, env map[string]interface{}) error
//
// This allows easier refactoring or interfacing later on if it becomes necessary.
package domain
import (
"fmt"
"io/ioutil"
"net/http"
"github.com/pkg/errors"
)
// ProcessResponse runs each extractor against resp in a fixed order (status
// code, then headers, then body) and returns the first error it encounters.
// Once an extractor fails, the remaining ones are not run.
func ProcessResponse(r *RequestTest, resp *http.Response, env map[string]interface{}) error {
	steps := []func(*RequestTest, *http.Response, map[string]interface{}) error{
		StatusCode,
		Header,
		Body,
	}
	for _, step := range steps {
		if err := step(r, resp, env); err != nil {
			return err
		}
	}
	return nil
}
// StatusCode extracts the response status code and checks it against
// r.WantsCode. A WantsCode of zero disables the check.
//
// The env parameter is unused. It is accepted so that every extractor shares
// the same signature.
func StatusCode(r *RequestTest, resp *http.Response, _ map[string]interface{}) error {
	if resp == nil {
		return errors.New("unexpected nil response")
	}
	if r.WantsCode == 0 || r.WantsCode == resp.StatusCode {
		return nil
	}
	// Print the numeric codes alongside the text: http.StatusText returns ""
	// for codes it does not recognise, which would otherwise leave a blank in
	// the message.
	return errors.Errorf("expected response code: %d %s, but got: %d %s",
		r.WantsCode, http.StatusText(r.WantsCode),
		resp.StatusCode, http.StatusText(resp.StatusCode),
	)
}
// Body extracts values from the response body and checks each one against
// the expectations in r.Body. Each entry is decoded by extractParam. When an
// entry names a variable, the extracted value is stored in env under that
// name, so env must be non-nil in that case.
//
// Body closes resp.Body before returning, on success and on every error path.
func Body(r *RequestTest, resp *http.Response, env map[string]interface{}) error {
	if resp == nil {
		return errors.New("unexpected nil response")
	}
	// Register the close before anything can fail. Otherwise an early return
	// (e.g. from NewBodyGetter) leaks the body and its underlying connection.
	if resp.Body != nil {
		defer resp.Body.Close()
	}

	// If we're unable to ascertain the body type, we won't
	// be able to extract anything and needn't bother reading
	// the response body.
	bodyGetter, err := NewBodyGetter(resp)
	if err != nil {
		return errors.Wrap(err, "creating body getter")
	}
	if resp.Body == nil {
		return errors.New("unexpected nil response body")
	}
	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return errors.Wrap(err, "reading response body")
	}

	for k, v := range r.Body {
		path, expected, set, err := extractParam(k, v)
		if err != nil {
			return errors.Wrap(err, "extracting body param")
		}
		actual, err := bodyGetter.Get(path, respBody)
		if err != nil {
			return err
		}
		if err = equals(expected, actual); err != nil {
			return errors.Wrap(err, "assertion failed")
		}
		if set != "" {
			env[set] = actual
		}
	}
	return nil
}
// getFirst returns one key/value pair from m and requires the value to be a
// string. extractParam treats the returned key as a lookup path and the value
// as the expected result.
//
// m is expected to hold a single entry. If it holds several, Go's randomized
// map iteration means an arbitrary entry is returned. An empty map yields
// ("", "", nil).
func getFirst(m map[string]interface{}) (key, val string, err error) {
	for k, v := range m {
		s, ok := v.(string)
		if !ok {
			// Report the dynamic type of v itself. Formatting the (zero-valued)
			// asserted string would always print "string".
			return "", "", errors.Errorf("expected string but got: %T", v)
		}
		return k, s, nil
	}
	return "", "", nil
}
func extractParam(key string, value interface{}) (path, expected, set string, err error) {
switch x := value.(type) {
case map[string]interface{}: // TOML unmarshals to this.
if path, expected, err = getFirst(x); err != nil {
return
}
return path, expected, key, err
case map[interface{}]interface{}: // YAML unmarshals to this.
ms := convertInterfaceMap(x)
if path, expected, err = getFirst(ms); err != nil {
return
}
return path, expected, key, err
case string:
return key, x, "", err
default:
return "", "", "", errors.Errorf("expected string but got: %T", x)
}
}
// convertInterfaceMap converts a map with arbitrary keys (as produced by YAML
// decoding) into a map keyed by the keys' default %v string form. Values are
// copied unchanged.
//
// Distinct keys that format identically (e.g. the int 1 and the string "1")
// collide. Which value survives then depends on map iteration order.
func convertInterfaceMap(in map[interface{}]interface{}) map[string]interface{} {
	result := make(map[string]interface{}, len(in))
	for k, v := range in {
		result[fmt.Sprintf("%v", k)] = v
	}
	return result
}
// Header extracts values from the response headers and checks each one
// against the expectations in r.Head. Each entry is decoded by extractParam.
// When an entry names a variable, the extracted value is stored in env under
// that name, so env must be non-nil in that case.
func Header(r *RequestTest, resp *http.Response, env map[string]interface{}) error {
	if resp == nil {
		return errors.New("unexpected nil response")
	}
	headerGetter := &HeaderGetter{}
	for k, v := range r.Head {
		path, expected, set, err := extractParam(k, v)
		if err != nil {
			return errors.Wrap(err, "extracting header param")
		}
		actual, err := headerGetter.Get(path, resp.Header)
		if err != nil {
			return err
		}
		if err = equals(expected, actual); err != nil {
			return errors.Wrap(err, "assertion failed")
		}
		if set != "" {
			env[set] = actual
		}
	}
	return nil
}
// equals reports whether the expected and actual strings match exactly. It
// returns nil on a match. Otherwise it returns an error that shows both
// values on separate indented lines.
func equals(exp, act string) error {
	if exp == act {
		return nil
	}
	return errors.Errorf("\n\texp: %v\n\tgot: %v", exp, act)
}