forked from quic-go/quic-go
/
request_writer.go
319 lines (290 loc) · 8.36 KB
/
request_writer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
package http3
import (
"bytes"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"
"github.com/PKURio/quic-go"
"github.com/PKURio/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"
)
// bodyCopyBufferSize is the size, in bytes, of the buffer WriteRequest uses
// to read the request body. Each non-empty read is sent as one DATA frame,
// so this is also the largest DATA frame payload WriteRequest produces.
const bodyCopyBufferSize = 8 * 1024
// requestWriter serializes HTTP requests onto QUIC streams: it encodes the
// request header list with a QPACK encoder and frames the body as DATA frames.
type requestWriter struct {
// mutex is held by writeHeaders while the encoder and headerBuf are in use.
mutex sync.Mutex
// encoder is the QPACK encoder; it is created on top of headerBuf, so
// every field written to it ends up in that buffer.
encoder *qpack.Encoder
// headerBuf collects the encoded header block of the request currently
// being written.
headerBuf *bytes.Buffer
// logger receives errors that occur while the request body is being sent.
logger utils.Logger
}
// newRequestWriter returns a requestWriter that logs to logger. The writer
// owns a fresh header buffer and a QPACK encoder that writes into it.
func newRequestWriter(logger utils.Logger) *requestWriter {
	w := &requestWriter{
		headerBuf: &bytes.Buffer{},
		logger:    logger,
	}
	w.encoder = qpack.NewEncoder(w.headerBuf)
	return w
}
// WriteRequest writes req to str.
//
// The HEADERS frame is encoded and written synchronously; any error from
// encoding or from that first write is returned. If req has no body, str is
// closed and WriteRequest returns. Otherwise the body is sent from a new
// goroutine as a sequence of DATA frames (at most bodyCopyBufferSize bytes of
// payload each) and WriteRequest returns immediately. Errors that occur while
// sending the body are only logged, not returned. The goroutine closes
// req.Body when it exits.
//
// gzip controls whether an "accept-encoding: gzip" header is added.
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error {
	buf := &bytes.Buffer{}
	if err := w.writeHeaders(buf, req, gzip); err != nil {
		return err
	}
	if _, err := str.Write(buf.Bytes()); err != nil {
		return err
	}
	// TODO: add support for trailers
	if req.Body == nil {
		// The error from Close is ignored here.
		str.Close()
		return nil
	}

	// send the request body asynchronously
	go func() {
		defer req.Body.Close()

		b := make([]byte, bodyCopyBufferSize)
		// frameHeader is reused for the header of every DATA frame.
		frameHeader := &bytes.Buffer{}
		for {
			n, rerr := req.Body.Read(b)
			// Only emit a DATA frame when there is payload. In particular, a
			// read that fails without returning data must not put an empty
			// DATA frame on the wire before the stream is cancelled.
			if n > 0 {
				frameHeader.Reset()
				(&dataFrame{Length: uint64(n)}).Write(frameHeader)
				if _, err := str.Write(frameHeader.Bytes()); err != nil {
					w.logger.Errorf("Error writing request: %s", err)
					return
				}
				if _, err := str.Write(b[:n]); err != nil {
					w.logger.Errorf("Error writing request: %s", err)
					return
				}
			}
			if rerr == io.EOF {
				break
			}
			if rerr != nil {
				// Reading the body failed: abort the stream instead of
				// closing it, so the request is not finished normally.
				str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
				w.logger.Errorf("Error writing request: %s", rerr)
				return
			}
			// n == 0 && rerr == nil: nothing was read, try again.
		}
		str.Close()
	}()

	return nil
}
// writeHeaders encodes the header list of req and writes it to wr as a single
// HEADERS frame: first the frame header, then the encoded header block.
//
// The encoder and its buffer are shared between calls, so the whole operation
// runs under w.mutex. The buffer is reset on every return path; otherwise a
// failed write to wr would leave this request's header block behind and it
// would be prepended to the header block of the next request.
func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {
	w.mutex.Lock()
	defer w.mutex.Unlock()
	// Deferred calls run in reverse order: the encoder is closed first, then
	// the buffer is cleared, then the mutex is released.
	defer w.headerBuf.Reset()
	defer w.encoder.Close()

	if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil {
		return err
	}

	buf := &bytes.Buffer{}
	hf := headersFrame{Length: uint64(w.headerBuf.Len())}
	hf.Write(buf)
	if _, err := wr.Write(buf.Bytes()); err != nil {
		return err
	}
	if _, err := wr.Write(w.headerBuf.Bytes()); err != nil {
		return err
	}
	return nil
}
// copied from net/transport.go
//
// encodeHeaders validates req and then hands its header list to w.encoder:
// the pseudo-header fields first, then an optional "trailer" field, the
// request's own headers, and finally the generated content-length,
// accept-encoding and user-agent fields. Field names are lower-cased before
// they are encoded.
//
// Every error is returned before the first field reaches the encoder, so a
// failed call leaves the encoder untouched.
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) error {
	authority := req.Host
	if authority == "" {
		authority = req.URL.Host
	}
	authority, err := httpguts.PunycodeHostPort(authority)
	if err != nil {
		return err
	}

	isConnect := req.Method == "CONNECT"

	var pseudoPath string
	if !isConnect {
		pseudoPath = req.URL.RequestURI()
		if !validPseudoPath(pseudoPath) {
			// The request URI may be in absolute form; retry with the
			// scheme and authority stripped off.
			rawPath := pseudoPath
			pseudoPath = strings.TrimPrefix(pseudoPath, req.URL.Scheme+"://"+authority)
			if !validPseudoPath(pseudoPath) {
				if req.URL.Opaque != "" {
					return fmt.Errorf("invalid request :path %q from URL.Opaque = %q", rawPath, req.URL.Opaque)
				}
				return fmt.Errorf("invalid request :path %q", rawPath)
			}
		}
	}

	// Reject invalid header names and values up front, before anything is
	// written to the encoder, so that the encoder stays reusable for later
	// requests.
	for name, values := range req.Header {
		if !httpguts.ValidHeaderFieldName(name) {
			return fmt.Errorf("invalid HTTP header name %q", name)
		}
		for _, value := range values {
			if !httpguts.ValidHeaderFieldValue(value) {
				return fmt.Errorf("invalid HTTP header value %q for header %q", value, name)
			}
		}
	}

	// forEachHeader calls emit once for every header field of the request,
	// in the order in which the fields are to be sent.
	forEachHeader := func(emit func(name, value string)) {
		// 8.1.2.3 Request Pseudo-Header Fields
		// The :path pseudo-header field includes the path and query parts
		// of the target URI (the path-absolute production and optionally a
		// '?' character followed by the query production (see Sections 3.3
		// and 3.4 of [RFC3986]).
		emit(":authority", authority)
		emit(":method", req.Method)
		if !isConnect {
			emit(":path", pseudoPath)
			emit(":scheme", req.URL.Scheme)
		}
		if trailers != "" {
			emit("trailer", trailers)
		}

		sawUserAgent := false
		for name, values := range req.Header {
			switch {
			case strings.EqualFold(name, "host"),
				strings.EqualFold(name, "content-length"):
				// Host is :authority, already sent.
				// Content-Length is automatic, set below.
				continue
			case strings.EqualFold(name, "connection"),
				strings.EqualFold(name, "proxy-connection"),
				strings.EqualFold(name, "transfer-encoding"),
				strings.EqualFold(name, "upgrade"),
				strings.EqualFold(name, "keep-alive"):
				// Per 8.1.2.2 Connection-Specific Header Fields, don't
				// send connection-specific fields. We have already
				// checked if any are error-worthy so just ignore the rest.
				continue
			case strings.EqualFold(name, "user-agent"):
				// Match Go's http1 behavior: at most one User-Agent. If
				// set to nil or empty string, then omit it. Otherwise if
				// not mentioned, include the default (below).
				sawUserAgent = true
				if len(values) == 0 || values[0] == "" {
					continue
				}
				values = values[:1]
			}
			for _, value := range values {
				emit(name, value)
			}
		}

		if shouldSendReqContentLength(req.Method, contentLength) {
			emit("content-length", strconv.FormatInt(contentLength, 10))
		}
		if addGzipHeader {
			emit("accept-encoding", "gzip")
		}
		if !sawUserAgent {
			emit("user-agent", defaultUserAgent)
		}
	}

	// First pass: add up the size of the header list. This is done as a
	// separate pass, before encoding, so that a future size check can bail
	// out without having modified the encoder state.
	var headerListSize uint64
	forEachHeader(func(name, value string) {
		headerListSize += uint64(hpack.HeaderField{Name: name, Value: value}.Size())
	})

	// TODO: check maximum header list size
	// if headerListSize > cc.peerMaxHeaderListSize {
	// 	return errRequestHeaderListSize
	// }

	// trace := httptrace.ContextClientTrace(req.Context())
	// traceHeaders := traceHasWroteHeaderField(trace)

	// Second pass: header list size is ok, write the headers.
	forEachHeader(func(name, value string) {
		w.encoder.WriteField(qpack.HeaderField{Name: strings.ToLower(name), Value: value})
		// if traceHeaders {
		// 	traceWroteHeaderField(trace, name, value)
		// }
	})

	return nil
}
// authorityAddr takes a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. If the authority has no port, or an empty one
// (as in "example.com:"), the default port for the scheme is added: 80 for
// "http", 443 for everything else.
func authorityAddr(scheme string, authority string) (addr string) {
	host, port, err := net.SplitHostPort(authority)
	if err != nil { // authority didn't have a port
		host = authority
		port = ""
	}
	// net.SplitHostPort accepts "host:" and reports an empty port for it, so
	// the default has to be applied in that case as well, not only when
	// splitting fails.
	if port == "" {
		port = "443"
		if scheme == "http" {
			port = "80"
		}
	}
	// Convert an internationalized host name to its ASCII form; if that
	// fails, the host is used as given.
	if a, err := idna.ToASCII(host); err == nil {
		host = a
	}
	// IPv6 address literal, without a port: the brackets are already there,
	// so net.JoinHostPort (which would add another pair) can't be used.
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		return host + ":" + port
	}
	return net.JoinHostPort(host, port)
}
// validPseudoPath reports whether v may be used as the value of the :path
// pseudo-header. Two forms are accepted:
//
//   - the string "*" (used by OPTIONS requests)
//   - any non-empty string that starts with '/'
//
// This is only a quick check, used to decide when an Opaque URL needs to be
// cleaned up before the request is sent. See golang.org/issue/16847.
//
// Paths starting with "//" are deliberately accepted, although the spec
// forbids them: Google's GFE accepts such paths and Chrome sends them.
// See golang.org/issue/19103.
func validPseudoPath(v string) bool {
	if v == "*" {
		return true
	}
	return v != "" && v[0] == '/'
}
// actualContentLength returns a sanitized version of req.ContentLength in
// which 0 really means "no body bytes" and -1 means "length unknown":
//
//   - a request without a body has length 0;
//   - a request with a body and a non-zero ContentLength reports that value
//     unchanged;
//   - a request with a body whose ContentLength is 0 has unknown length (-1).
func actualContentLength(req *http.Request) int64 {
	switch {
	case req.Body == nil:
		return 0
	case req.ContentLength == 0:
		return -1
	default:
		return req.ContentLength
	}
}
// shouldSendReqContentLength reports whether a "content-length" request
// header should be sent. The logic is basically a copy of net/http's
// transferWriter.shouldSendContentLength.
//
// contentLength is the corrected content length: 0 means the body is really
// empty and a negative value (-1) means the length is unknown. A known,
// positive length is always sent and an unknown length never is. For an
// empty body the header is sent only for POST, PUT and PATCH requests.
func shouldSendReqContentLength(method string, contentLength int64) bool {
	if contentLength != 0 {
		return contentLength > 0
	}
	// Zero-length body: whether we send a content-length depends on the
	// method.
	return method == "POST" || method == "PUT" || method == "PATCH"
}