/
std_net.go
193 lines (172 loc) · 5.15 KB
/
std_net.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
package syntax
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"github.com/pkg/errors"
"github.com/arr-ai/arrai/pkg/arraictx"
"github.com/arr-ai/arrai/rel"
"github.com/arr-ai/arrai/tools"
)
// httpConfig carries the options extracted from the config tuple that is passed as the
// first argument to the //net.http functions.
type httpConfig struct {
	// header maps header names to their values. It stays nil when the config tuple
	// has no "header" attribute.
	header map[string][]string
}
// stdNet returns the "net" attribute of the arr.ai standard library. It holds a single
// tuple, "http", containing two functions:
//
//   - get(config, url) sends a GET request;
//   - post(config, url, body) sends a POST request.
//
// config must be a tuple (see parseConfig), url a string (see parseURL) and body a
// string or bytes (see parseBody). Both functions return an error instead of making
// a request while arraictx.IsCompiling(ctx) reports true.
//
// NOTE(review): the error messages name "//net.get" / "//net.post" although the
// functions are registered under net.http. In addition, ctx is not forwarded to the
// HTTP request, so cancelling it does not abort an in-flight request. Confirm whether
// either should change.
func stdNet() rel.Attr {
	httpGet := func(ctx context.Context, configArg, urlArg rel.Value) (rel.Value, error) {
		if arraictx.IsCompiling(ctx) {
			return nil, errors.New("//net.get is disabled during compilation")
		}
		cfg, err := parseConfig(configArg)
		if err != nil {
			return nil, err
		}
		target, err := parseURL(urlArg)
		if err != nil {
			return nil, err
		}
		return get(target, cfg.header)
	}

	httpPost := func(ctx context.Context, configArg, urlArg, bodyArg rel.Value) (rel.Value, error) {
		if arraictx.IsCompiling(ctx) {
			return nil, errors.New("//net.post is disabled during compilation")
		}
		cfg, err := parseConfig(configArg)
		if err != nil {
			return nil, err
		}
		target, err := parseURL(urlArg)
		if err != nil {
			return nil, err
		}
		payload, err := parseBody(bodyArg)
		if err != nil {
			return nil, err
		}
		return post(target, cfg.header, payload)
	}

	httpTuple := rel.NewTupleAttr("http",
		createFunc2Attr("get", httpGet),
		createFunc3Attr("post", httpPost),
	)
	return rel.NewTupleAttr("net", httpTuple)
}
// send issues an HTTP request with the given method, URL, headers and body. It returns
// the response converted to an arr.ai value by parseResponse.
//
// Header names are added via http.Header.Add, which canonicalizes them (for example
// "content-type" becomes "Content-Type"). net/http looks headers up by their canonical
// key, so installing the caller's map directly would leave non-canonical names
// invisible to it. A lower-case "user-agent", for instance, would be sent in addition
// to Go's default User-Agent. Copying the values also keeps the request from aliasing
// the caller's map.
func send(method, url string, headers map[string][]string, body io.Reader) (rel.Value, error) {
	req, err := http.NewRequest(method, url, body)
	if err != nil {
		return nil, err
	}
	for name, values := range headers {
		for _, v := range values {
			req.Header.Add(name, v)
		}
	}
	// NOTE(review): http.DefaultClient has no timeout, so a non-responding server can
	// block this call indefinitely. Consider a dedicated client or a context deadline.
	res, err := http.DefaultClient.Do(req) //nolint:gosec
	if err != nil {
		return nil, err
	}
	return parseResponse(res)
}
// get performs an HTTP GET of url with the given headers and returns the response
// converted to an arr.ai value. The request is built with an empty body.
func get(url string, headers map[string][]string) (rel.Value, error) {
	emptyBody := strings.NewReader("")
	return send(http.MethodGet, url, headers, emptyBody)
}
// post sends a POST request and returns a value wrapping the response.
func post(url string, headers map[string][]string, body io.Reader) (rel.Value, error) {
return send("POST", url, headers, body)
}
// parseConfig interprets the first (config) argument of the //net.http functions.
//
// The argument must be a *rel.GenericTuple. Its optional "header" attribute is
// converted with parseHeader. When that attribute is absent, the returned config has a
// nil header map. Errors from parseHeader are returned unchanged.
func parseConfig(configArg rel.Value) (*httpConfig, error) {
	tuple, isTuple := configArg.(*rel.GenericTuple)
	if !isTuple {
		return nil, errors.Errorf("first arg (config) must be tuple, not %s", rel.ValueTypeAsString(configArg))
	}
	cfg := &httpConfig{}
	if raw, found := tuple.Get("header"); found {
		parsed, err := parseHeader(raw)
		if err != nil {
			return nil, err
		}
		cfg.header = parsed
	}
	return cfg, nil
}
// parseHeader converts the "header" attribute of a config tuple into a Go header map.
//
// header must be a rel.Dict whose keys are strings. Each value must be either a single
// string or an array of strings (several values for the same header name). Any other
// shape produces an error that names the offending value's type.
func parseHeader(header rel.Value) (map[string][]string, error) {
	headDict, ok := header.(rel.Dict)
	if !ok {
		// Describe the argument that was actually passed. After a failed type
		// assertion, headDict is only the zero value of rel.Dict.
		return nil, errors.Errorf("header must be a dict, not %s", rel.ValueTypeAsString(header))
	}
	out := make(map[string][]string, headDict.Count())
	for e := headDict.DictEnumerator(); e.MoveNext(); {
		kv, vv := e.Current()
		k, ok := tools.ValueAsString(kv)
		if !ok {
			return nil, errors.Errorf("header keys must be strings, not %s", rel.ValueTypeAsString(kv))
		}
		switch t := vv.(type) {
		case rel.String:
			out[k] = []string{t.String()}
		case rel.Array:
			vs := make([]string, t.Count())
			for i, val := range t.Values() {
				valStr, is := tools.ValueAsString(val)
				if !is {
					return nil, errors.Errorf(
						"header values must be strings or string arrays, not arrays of %s", rel.ValueTypeAsString(val))
				}
				vs[i] = valStr
			}
			out[k] = vs
		default:
			// A value that is not a rel.String may still denote a string. The same
			// conversion is applied to keys above, presumably covering e.g. the empty
			// string; TODO confirm which representations tools.ValueAsString accepts.
			// Only values that previously errored take this path.
			s, isStr := tools.ValueAsString(vv)
			if !isStr {
				return nil, errors.Errorf("header values must be strings or string arrays, not %s", rel.ValueTypeAsString(vv))
			}
			out[k] = []string{s}
		}
	}
	return out, nil
}
// parseURL interprets the second (url) argument of the //net.http functions. It
// accepts any value that tools.ValueAsString can convert and returns an error otherwise.
func parseURL(urlArg rel.Value) (string, error) {
	if s, ok := tools.ValueAsString(urlArg); ok {
		return s, nil
	}
	return "", errors.Errorf("second arg (url) must be a string, not %s", rel.ValueTypeAsString(urlArg))
}
// parseBody interprets the third (body) argument of //net.http.post as a request body.
// A value convertible to a string is tried first, then a value convertible to bytes.
// Anything else is rejected with an error.
func parseBody(bodyArg rel.Value) (io.Reader, error) {
	if text, ok := tools.ValueAsString(bodyArg); ok {
		return strings.NewReader(text), nil
	}
	if raw, ok := tools.ValueAsBytes(bodyArg); ok {
		return bytes.NewReader(raw), nil
	}
	return nil, errors.Errorf("third arg (body) must be a string or bytes, not %s", rel.ValueTypeAsString(bodyArg))
}
// parseResponse converts resp into an arr.ai tuple with these attributes:
//
//   - status: the status text (resp.Status) as a string;
//   - status_code: the numeric status code;
//   - header: a dict mapping each response header name to an array of its values;
//   - body: the complete response body as bytes.
//
// The response body is read in full and is always closed, including when reading it
// fails.
func parseResponse(resp *http.Response) (rel.Value, error) {
	// Register Close before reading. Previously it was deferred only after a successful
	// read, so a failed io.ReadAll leaked the body and its underlying connection.
	defer resp.Body.Close()
	buf, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	entries := make([]rel.DictEntryTuple, 0, len(resp.Header))
	for k, vs := range resp.Header {
		vals := make([]rel.Value, len(vs))
		for j, v := range vs {
			vals[j] = rel.NewString([]rune(v))
		}
		entries = append(entries, rel.NewDictEntryTuple(rel.NewString([]rune(k)), rel.NewArray(vals...)))
	}
	// Keys of an http.Header map are unique, so the entries contain no duplicate keys.
	header := rel.MustNewDict(false, entries...)
	return rel.NewTuple(
		rel.NewAttr("status", rel.NewString([]rune(resp.Status))),
		rel.NewAttr("status_code", rel.NewNumber(float64(resp.StatusCode))),
		rel.NewAttr("header", header),
		rel.NewAttr("body", rel.NewBytes(buf)),
	), nil
}