/
tlsdialer.go
245 lines (211 loc) · 7.29 KB
/
tlsdialer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
// Package tlsdialer contains a customized version of crypto/tls.Dial that
// allows control over whether or not to send the ServerName extension in the
// client handshake.
package tlsdialer
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"time"
"github.com/getlantern/golog"
)
// Package-level state shared by every dial performed through this package.
var (
	// log is the logger used for this package's trace and debug output.
	log = golog.LoggerFor("tlsdialer")

	// resolve turns an address into a TCP address. By default it delegates to
	// net.ResolveTCPAddr on the "tcp" network; OverrideResolve replaces it.
	resolve = func(addr string) (*net.TCPAddr, error) {
		tcpAddr, resolveErr := net.ResolveTCPAddr("tcp", addr)
		if resolveErr == nil {
			return tcpAddr, nil
		}
		return nil, resolveErr
	}

	// dialOverride, when non-nil, is used by DialForTimings to connect the
	// socket instead of the caller-supplied net.Dialer. It is set through
	// OverrideDial.
	dialOverride func(network, addr string, timeout time.Duration) (net.Conn, error)
)
// timeoutError is the error reported when the overall dial timeout elapses
// before resolution or the TLS handshake has finished. Its method set (Error,
// Timeout, Temporary) matches the net.Error interface.
type timeoutError struct{}

// Error returns the fixed timeout message.
func (timeoutError) Error() string {
	return "tlsdialer: DialWithDialer timed out"
}

// Timeout always reports true: this error only ever represents a timeout.
func (timeoutError) Timeout() bool {
	return true
}

// Temporary always reports true.
func (timeoutError) Temporary() bool {
	return true
}
// ConnWithTimings is a tls.Conn along with timings for key steps in
// establishing that Conn.
type ConnWithTimings struct {
	// Conn: the conn resulting from dialing. DialForTimings only sets it once
	// the handshake and any certificate verification have succeeded.
	Conn *tls.Conn
	// ResolutionTime: the amount of time it took to resolve the address
	ResolutionTime time.Duration
	// ConnectTime: the amount of time that it took to connect the socket
	ConnectTime time.Duration
	// HandshakeTime: the amount of time that it took to complete the TLS
	// handshake. Left at zero when the handshake fails.
	HandshakeTime time.Duration
	// ResolvedAddr: the address to which our dns lookup resolved
	ResolvedAddr *net.TCPAddr
	// VerifiedChains: like tls.ConnectionState.VerifiedChains, but also
	// populated when the ServerName extension was not sent and the
	// certificates were verified manually. Not populated when the supplied
	// config has InsecureSkipVerify set.
	VerifiedChains [][]*x509.Certificate
}
// OverrideResolve allows overriding the DNS resolution function used when
// dialing. The override receives the address passed to the dial functions and
// returns the TCP address to connect to.
//
// Passing nil restores the default resolver (net.ResolveTCPAddr on the "tcp"
// network), mirroring OverrideDial where nil means "use the default".
//
// The override is stored in package-level state with a plain assignment, so
// it is not synchronized with dials that are already in progress.
func OverrideResolve(override func(addr string) (*net.TCPAddr, error)) {
	if override == nil {
		// Storing nil would make the next dial panic when it calls resolve,
		// so fall back to the default behavior instead.
		override = func(addr string) (*net.TCPAddr, error) {
			return net.ResolveTCPAddr("tcp", addr)
		}
	}
	resolve = override
}
// OverrideDial allows specifying a function that will be used to dial in lieu
// of a net.Dialer. DialForTimings calls the override with the network, the
// already-resolved address and the overall timeout derived from the dialer
// (0 when the dialer has neither a Timeout nor a Deadline).
//
// Passing nil restores dialing through the net.Dialer supplied by the caller.
func OverrideDial(override func(network, addr string, timeout time.Duration) (net.Conn, error)) {
	dialOverride = override
}
// Dial works like crypto/tls.Dial, except that the sendServerName flag
// controls whether or not the ServerName extension is sent in the client
// handshake. It dials through a zero-value net.Dialer, i.e. one with no
// Timeout and no Deadline set.
//
// Note - if sendServerName is false, the VerifiedChains field on the
// connection's ConnectionState will never get populated. Use DialForTimings to
// get back a data structure that includes the verified chains.
func Dial(network, addr string, sendServerName bool, config *tls.Config) (*tls.Conn, error) {
	defaultDialer := &net.Dialer{}
	return DialWithDialer(defaultDialer, network, addr, sendServerName, config)
}
// DialWithDialer works like crypto/tls.DialWithDialer, except that the
// sendServerName flag controls whether or not the ServerName extension is sent
// in the client handshake. It delegates to DialForTimings and discards
// everything but the connection and the error.
//
// Note - if sendServerName is false, the VerifiedChains field on the
// connection's ConnectionState will never get populated. Use DialForTimings to
// get back a data structure that includes the verified chains.
func DialWithDialer(dialer *net.Dialer, network, addr string, sendServerName bool, config *tls.Config) (*tls.Conn, error) {
	timed, dialErr := DialForTimings(dialer, network, addr, sendServerName, config)
	conn := timed.Conn
	return conn, dialErr
}
// DialForTimings is like DialWithDialer but returns a data structure including
// timings and the verified chains.
//
// The effective timeout is the dialer's Timeout or the time remaining until
// its Deadline, whichever is shorter. When one is set, address resolution and
// the TLS handshake run on goroutines and are abandoned with a timeout error
// once it elapses. The socket connect step is bounded separately: by the
// dialer itself, or by the timeout value handed to the function installed via
// OverrideDial.
//
// If sendServerName is true, config.ServerName (or, when that is empty, the
// host part of addr) is sent and verification is left to the TLS handshake.
// If it is false, no ServerName is sent, verification inside the handshake is
// disabled and, unless config.InsecureSkipVerify is set, the peer
// certificates are verified manually against that same server name.
//
// The returned *ConnWithTimings is never nil. On error it holds whatever was
// measured before the failure and its Conn is nil; if the failure happened
// during or after the handshake, the underlying connection has been closed.
func DialForTimings(dialer *net.Dialer, network, addr string, sendServerName bool, config *tls.Config) (*ConnWithTimings, error) {
	result := &ConnWithTimings{}

	// We want the Timeout and Deadline values from dialer to cover the
	// whole process: TCP connection and TLS handshake. This means that we
	// also need to start our own timers now.
	timeout := dialer.Timeout
	if !dialer.Deadline.IsZero() {
		deadlineTimeout := time.Until(dialer.Deadline)
		if timeout == 0 || deadlineTimeout < timeout {
			timeout = deadlineTimeout
		}
	}

	var errCh chan error
	if timeout != 0 {
		// errCh receives at most three values (timer, resolution, handshake).
		// The buffer guarantees that senders finishing after we have already
		// returned never block.
		errCh = make(chan error, 10)
		timer := time.AfterFunc(timeout, func() {
			errCh <- timeoutError{}
		})
		// Release the timer as soon as we are done instead of leaving it
		// pending until the timeout elapses.
		defer timer.Stop()
	}

	log.Tracef("Resolving addr: %s", addr)
	start := time.Now()
	var err error
	if timeout == 0 {
		log.Tracef("Resolving immediately")
		result.ResolvedAddr, err = resolve(addr)
	} else {
		log.Tracef("Resolving on goroutine")
		resolvedCh := make(chan *net.TCPAddr, 1)
		go func() {
			resolved, resolveErr := resolve(addr)
			log.Tracef("Resolution resulted in %v : %v", resolved, resolveErr)
			// Publish the address before the error so that it is already
			// available once the nil error has been received below.
			resolvedCh <- resolved
			errCh <- resolveErr
		}()
		err = <-errCh
		if err == nil {
			log.Tracef("No error, looking for resolved")
			result.ResolvedAddr = <-resolvedCh
		}
	}
	if err != nil {
		return result, err
	}
	result.ResolutionTime = time.Since(start)
	log.Tracef("Resolved addr %s to %s in %s", addr, result.ResolvedAddr, result.ResolutionTime)

	hostname, _, err := net.SplitHostPort(addr)
	if err != nil {
		return result, fmt.Errorf("Unable to split host and port for %v: %v", addr, err)
	}

	log.Tracef("Dialing %s %s (%s)", network, addr, result.ResolvedAddr)
	start = time.Now()
	resolvedAddr := result.ResolvedAddr.String()
	var rawConn net.Conn
	if dialOverride != nil {
		log.Trace("Dialing with dialOverride")
		rawConn, err = dialOverride(network, resolvedAddr, timeout)
	} else {
		rawConn, err = dialer.Dial(network, resolvedAddr)
	}
	if err != nil {
		return result, err
	}
	result.ConnectTime = time.Since(start)
	log.Tracef("Dialed in %s", result.ConnectTime)

	if config == nil {
		config = &tls.Config{}
	}

	serverName := config.ServerName
	if serverName == "" {
		log.Trace("No ServerName set, inferring from the hostname to which we're connecting")
		serverName = hostname
	}
	log.Tracef("ServerName is: %s", serverName)

	log.Trace("Copying config so that we can tweak it")
	// Clone rather than copy by value: tls.Config carries internal
	// synchronization state that must not be copied.
	configCopy := config.Clone()

	if sendServerName {
		log.Tracef("Setting ServerName to %s and relying on the usual logic in tls.Conn.Handshake() to do verification", serverName)
		configCopy.ServerName = serverName
	} else {
		log.Trace("Clearing ServerName and disabling verification in tls.Conn.Handshake(). We'll verify manually after handshaking.")
		configCopy.ServerName = ""
		configCopy.InsecureSkipVerify = true
	}

	conn := tls.Client(rawConn, configCopy)

	start = time.Now()
	if timeout == 0 {
		log.Trace("Handshaking immediately")
		err = conn.Handshake()
	} else {
		log.Trace("Handshaking on goroutine")
		go func() {
			errCh <- conn.Handshake()
		}()
		// Either the handshake result or the timeout, whichever comes first.
		// On timeout the connection is closed below, which also unblocks the
		// handshake goroutine.
		err = <-errCh
	}
	if err == nil {
		result.HandshakeTime = time.Since(start)
		log.Tracef("Finished handshaking in: %s", result.HandshakeTime)
	}

	// Check the caller's config here, not configCopy, whose InsecureSkipVerify
	// may have been forced to true above.
	if err == nil && !config.InsecureSkipVerify {
		if sendServerName {
			log.Trace("Depending on certificate verification in tls.Conn.Handshake()")
			result.VerifiedChains = conn.ConnectionState().VerifiedChains
		} else {
			log.Trace("Manually verifying certificates")
			result.VerifiedChains, err = verifyServerCerts(conn, serverName, configCopy)
		}
	}
	if err != nil {
		log.Trace("Handshake or verification error, closing underlying connection")
		if closeErr := rawConn.Close(); closeErr != nil {
			log.Debugf("Unable to close connection: %v", closeErr)
		}
		return result, err
	}

	result.Conn = conn
	return result, nil
}
// verifyServerCerts manually verifies the certificates presented by the peer
// on conn, for use when verification inside the TLS handshake was disabled.
//
// The first peer certificate is treated as the leaf and verified for
// serverName against config.RootCAs, with all remaining peer certificates
// offered as intermediates. The verification time comes from config.Time when
// it is set and from time.Now otherwise.
//
// It returns the verified chains, or an error if verification fails or the
// peer presented no certificates at all.
func verifyServerCerts(conn *tls.Conn, serverName string, config *tls.Config) ([][]*x509.Certificate, error) {
	certs := conn.ConnectionState().PeerCertificates
	if len(certs) == 0 {
		// Without this guard, indexing the leaf below would panic.
		return nil, fmt.Errorf("tlsdialer: no peer certificates to verify for %v", serverName)
	}

	// Honor a caller-supplied clock so that manual verification uses the same
	// notion of "now" as the caller configured for TLS.
	now := time.Now()
	if config.Time != nil {
		now = config.Time()
	}

	intermediates := x509.NewCertPool()
	for _, cert := range certs[1:] {
		intermediates.AddCert(cert)
	}

	return certs[0].Verify(x509.VerifyOptions{
		Roots:         config.RootCAs,
		CurrentTime:   now,
		DNSName:       serverName,
		Intermediates: intermediates,
	})
}