// Source: fwd.go (forked from getlantern/lantern) — 151 lines (132 loc), 3.49 KB.

// Package forward implements an HTTP handler that forwards requests to a
// remote server and serves the upstream response back to the client.
package forward
import (
"io"
"net/http"
"net/url"
"os"
"strconv"
"time"
"github.com/mailgun/oxy/utils"
)
// ReqRewriter is implemented by types that adjust an outbound request (its
// headers, body, and so on) before the Forwarder sends it upstream.
// Rewrite modifies r in place and returns nothing.
type ReqRewriter interface {
	Rewrite(r *http.Request)
}
// optSetter is a functional option accepted by New. It configures one field
// of a Forwarder and may reject the configuration by returning an error.
type optSetter func(f *Forwarder) error

// PassHostHeader controls whether the client's Host header is forwarded
// as-is (true) or replaced with the target URL's host (false). New leaves
// this false unless the option is given.
func PassHostHeader(b bool) optSetter {
	setPassHost := func(f *Forwarder) error {
		f.passHost = b
		return nil
	}
	return setPassHost
}

// RoundTripper sets the transport used to send requests upstream. If it is
// not supplied (or is nil), New falls back to http.DefaultTransport.
func RoundTripper(r http.RoundTripper) optSetter {
	setTransport := func(f *Forwarder) error {
		f.roundTripper = r
		return nil
	}
	return setTransport
}

// Rewriter sets the ReqRewriter applied to every outbound request. If it is
// not supplied (or is nil), New installs a HeaderRewriter.
func Rewriter(r ReqRewriter) optSetter {
	setRewriter := func(f *Forwarder) error {
		f.rewriter = r
		return nil
	}
	return setRewriter
}

// ErrorHandler is a functional argument that sets the handler invoked when
// the upstream round trip fails. If it is not supplied (or is nil), New uses
// utils.DefaultHandler.
func ErrorHandler(h utils.ErrorHandler) optSetter {
	setErrHandler := func(f *Forwarder) error {
		f.errHandler = h
		return nil
	}
	return setErrHandler
}

// Logger sets the logger used for per-request log lines. If it is not
// supplied (or is nil), New uses utils.NullLogger.
func Logger(l utils.Logger) optSetter {
	setLogger := func(f *Forwarder) error {
		f.log = l
		return nil
	}
	return setLogger
}
// Forwarder relays each incoming request to the scheme and host named in the
// request's own URL, then copies the upstream response back to the client.
// *Forwarder implements http.Handler. Construct it with New so that unset
// fields receive their defaults.
type Forwarder struct {
	errHandler   utils.ErrorHandler // invoked when the upstream round trip returns an error
	roundTripper http.RoundTripper  // transport that performs the upstream request
	rewriter     ReqRewriter        // applied to every outbound request copy before sending
	log          utils.Logger       // receives per-request Infof/Errorf lines
	passHost     bool               // keep the client's Host header instead of the target host
}
// New builds a Forwarder by applying each option in order. It stops and
// returns the first error an option reports.
//
// Fields still unset after the options run receive defaults:
//   - roundTripper: http.DefaultTransport
//   - rewriter: a HeaderRewriter that trusts forward headers and is named
//     after this machine's hostname ("localhost" if it cannot be read)
//   - log: utils.NullLogger
//   - errHandler: utils.DefaultHandler
func New(setters ...optSetter) (*Forwarder, error) {
	fwd := new(Forwarder)
	for i := range setters {
		if err := setters[i](fwd); err != nil {
			return nil, err
		}
	}

	if fwd.roundTripper == nil {
		fwd.roundTripper = http.DefaultTransport
	}
	if fwd.rewriter == nil {
		hostname := "localhost"
		if name, err := os.Hostname(); err == nil {
			hostname = name
		}
		fwd.rewriter = &HeaderRewriter{TrustForwardHeader: true, Hostname: hostname}
	}
	if fwd.log == nil {
		fwd.log = utils.NullLogger
	}
	if fwd.errHandler == nil {
		fwd.errHandler = utils.DefaultHandler
	}
	return fwd, nil
}
// ServeHTTP forwards req to the scheme and host in req.URL, then copies the
// upstream status code, headers and body back to w.
//
// If the round trip fails, the error is logged and passed to the configured
// error handler. Once the upstream status has been written to w, nothing else
// can be reported to the client. Body-copy failures are therefore only logged.
func (f *Forwarder) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	// Plain time.Now() keeps the monotonic clock reading. Calling .UTC()
	// would strip it and let wall-clock jumps distort the logged duration.
	start := time.Now()
	response, err := f.roundTripper.RoundTrip(f.copyRequest(req, req.URL))
	if err != nil {
		f.log.Errorf("Error forwarding to %v, err: %v", req.URL, err)
		f.errHandler.ServeHTTP(w, req, err)
		return
	}
	// Release the upstream body on every exit path, including a panic
	// while copying to w.
	defer response.Body.Close()

	duration := time.Since(start)
	if req.TLS != nil {
		f.log.Infof("Round trip: %v, code: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v",
			req.URL, response.StatusCode, duration,
			req.TLS.Version,
			req.TLS.DidResume,
			req.TLS.CipherSuite,
			req.TLS.ServerName)
	} else {
		f.log.Infof("Round trip: %v, code: %v, duration: %v",
			req.URL, response.StatusCode, duration)
	}

	utils.CopyHeaders(w.Header(), response.Header)
	// Header changes after WriteHeader are ignored. Any Content-Length must
	// therefore be set now, not after the body is copied (the old code did
	// that, and it was a no-op). Only fill it in when the upstream headers
	// lack it but the body length is known.
	if response.ContentLength > 0 && w.Header().Get(ContentLength) == "" {
		w.Header().Set(ContentLength, strconv.FormatInt(response.ContentLength, 10))
	}
	w.WriteHeader(response.StatusCode)
	if _, err := io.Copy(w, response.Body); err != nil {
		f.log.Errorf("Error copying response body from %v, err: %v", req.URL, err)
	}
}
// copyRequest returns a copy of req that targets u's scheme and host and is
// ready to hand to the round tripper.
//
// The copy gets its own URL and Header, so rewriting those never touches the
// caller's request. All other reference-typed fields (Body, Trailer, TLS, ...)
// are shallow copies shared with req. The configured rewriter, if any, is
// applied to the copy before it is returned.
func (f *Forwarder) copyRequest(req *http.Request, u *url.URL) *http.Request {
	outReq := new(http.Request)
	*outReq = *req // shallow copy; URL and Header are replaced below

	outReq.URL = utils.CopyURL(req.URL)
	outReq.URL.Scheme = u.Scheme
	outReq.URL.Host = u.Host
	// Forward the request-target exactly as the client sent it, so path
	// escaping is preserved. RequestURI already contains the query, so clear
	// RawQuery to avoid sending it twice. RequestURI is only set on
	// server-side requests. Otherwise keep the parsed path and query; the
	// previous code dropped the query in that case.
	if req.RequestURI != "" {
		outReq.URL.Opaque = req.RequestURI
		outReq.URL.RawQuery = ""
	}

	// Do not pass the client's Host header unless PassHostHeader(true) was set.
	if !f.passHost {
		outReq.Host = u.Host
	}

	outReq.Proto = "HTTP/1.1"
	outReq.ProtoMajor, outReq.ProtoMinor = 1, 1
	// Overwrite the close flag so connections to the backend can be reused.
	outReq.Close = false

	outReq.Header = make(http.Header, len(req.Header))
	utils.CopyHeaders(outReq.Header, req.Header)

	if f.rewriter != nil {
		f.rewriter.Rewrite(outReq)
	}
	return outReq
}