-
Notifications
You must be signed in to change notification settings - Fork 241
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge in DNS/dnsproxy from 5251-close-ups to master Updates AdguardTeam/AdGuardHome#5251. Squashed commit of the following: commit c4f01d3 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Wed Dec 14 17:57:19 2022 +0300 upstream: enhance filtered errs commit ce33519 Merge: 175680b 0ce51f5 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Wed Dec 14 13:29:27 2022 +0300 Merge branch 'master' into 5251-close-ups commit 175680b Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Tue Dec 13 15:06:01 2022 +0300 upstream: fix doc commit 9ecfc0b Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Tue Dec 13 14:59:04 2022 +0300 all: imp docs commit b275d56 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Tue Dec 13 02:11:13 2022 +0300 upstream: imp dot commit 0ed84f4 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Mon Dec 12 20:40:35 2022 +0300 fastip: fix golangci-lint issues commit c5ed13b Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Mon Dec 12 20:15:45 2022 +0300 all: fix staticcheck issues commit 6c8b28c Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Mon Dec 12 20:05:24 2022 +0300 upstream: imp more commit e4a2374 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Mon Dec 12 19:01:57 2022 +0300 upstream: upd golibs, imp code commit cae1610 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Mon Dec 12 16:27:21 2022 +0300 upstream: filter dot errs commit 51005ab Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Mon Dec 12 14:10:23 2022 +0300 upstream: use sync pool commit 81e907c Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Sat Dec 10 03:37:21 2022 +0300 upstream: imp dot
- Loading branch information
1 parent
0ce51f5
commit dc6b896
Showing
10 changed files
with
320 additions
and
285 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,134 +1,247 @@ | ||
package upstream | ||
|
||
import ( | ||
"context" | ||
"crypto/tls" | ||
"fmt" | ||
"io" | ||
"net" | ||
"net/url" | ||
"os" | ||
"runtime" | ||
"sync" | ||
"time" | ||
|
||
"github.com/AdguardTeam/golibs/errors" | ||
"github.com/AdguardTeam/golibs/log" | ||
"github.com/miekg/dns" | ||
) | ||
|
||
// dialTimeout is the global timeout for establishing a TLS connection. | ||
// TODO(ameshkov): use bootstrap timeout instead. | ||
const dialTimeout = 10 * time.Second | ||
|
||
// dnsOverTLS is a struct that implements the Upstream interface for the | ||
// DNS-over-TLS protocol. | ||
type dnsOverTLS struct { | ||
boot *bootstrapper | ||
pool *TLSPool | ||
poolMu sync.Mutex | ||
// boot resolves the hostname upstream addresses. | ||
boot *bootstrapper | ||
|
||
// connsMu protects conns. | ||
connsMu sync.Mutex | ||
|
||
// conns stores the connections ready for reuse. Don't use [sync.Pool] | ||
// here, since there is no need to deallocate these connections. | ||
// | ||
// TODO(e.burkov, ameshkov): Currently connections just stored in FILO | ||
// order, which eventually makes most of them unusable due to timeouts. | ||
// This leads to weak performance for all exchanges coming across such | ||
// connections. | ||
conns []net.Conn | ||
} | ||
|
||
// type check | ||
var _ Upstream = (*dnsOverTLS)(nil) | ||
|
||
// newDoT returns the DNS-over-TLS Upstream. | ||
func newDoT(uu *url.URL, opts *Options) (u Upstream, err error) { | ||
addPort(uu, defaultPortDoT) | ||
func newDoT(u *url.URL, opts *Options) (ups Upstream, err error) { | ||
addPort(u, defaultPortDoT) | ||
|
||
var b *bootstrapper | ||
b, err = urlToBoot(uu, opts) | ||
boot, err := urlToBoot(u, opts) | ||
if err != nil { | ||
return nil, fmt.Errorf("creating tls bootstrapper: %w", err) | ||
} | ||
|
||
u = &dnsOverTLS{boot: b} | ||
ups = &dnsOverTLS{ | ||
boot: boot, | ||
} | ||
|
||
runtime.SetFinalizer(u, (*dnsOverTLS).Close) | ||
runtime.SetFinalizer(ups, (*dnsOverTLS).Close) | ||
|
||
return u, nil | ||
return ups, nil | ||
} | ||
|
||
// Address implements the Upstream interface for *dnsOverTLS. | ||
// Address implements the [Upstream] interface for *dnsOverTLS. | ||
func (p *dnsOverTLS) Address() string { return p.boot.URL.String() } | ||
|
||
// Exchange implements the Upstream interface for *dnsOverTLS. | ||
// Exchange implements the [Upstream] interface for *dnsOverTLS. | ||
func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) { | ||
pool := p.getPool() | ||
|
||
poolConn, err := pool.Get() | ||
conn, err := p.conn() | ||
if err != nil { | ||
return nil, fmt.Errorf("getting connection to %s: %w", p.Address(), err) | ||
return nil, fmt.Errorf("getting conn to %s: %w", p.Address(), err) | ||
} | ||
|
||
logBegin(p.Address(), m) | ||
reply, err = p.exchangeConn(poolConn, m) | ||
logFinish(p.Address(), err) | ||
|
||
reply, err = p.exchangeWithConn(conn, m) | ||
if err != nil { | ||
log.Tracef("The TLS connection is expired due to %s", err) | ||
// The pooled connection might have been closed already, see | ||
// https://github.com/AdguardTeam/dnsproxy/issues/3. The following | ||
// connection from pool may also be malformed, so dial a new one. | ||
|
||
// The pooled connection might have been closed already (see https://github.com/AdguardTeam/dnsproxy/issues/3) | ||
// So we're trying to re-connect right away here. | ||
// We are forcing creation of a new connection instead of calling Get() again | ||
// as there's no guarantee that other pooled connections are intact | ||
poolConn, err = pool.Create() | ||
err = errors.WithDeferred(err, conn.Close()) | ||
log.Debug("dot upstream: bad conn from pool: %s", err) | ||
|
||
// Retry. | ||
conn, err = p.dial() | ||
if err != nil { | ||
return nil, fmt.Errorf("creating new connection to %s: %w", p.Address(), err) | ||
return nil, fmt.Errorf("dialing conn to %s: %w", p.Address(), err) | ||
} | ||
|
||
// Retry sending the DNS request | ||
logBegin(p.Address(), m) | ||
reply, err = p.exchangeConn(poolConn, m) | ||
logFinish(p.Address(), err) | ||
reply, err = p.exchangeWithConn(conn, m) | ||
if err != nil { | ||
return reply, errors.WithDeferred(err, conn.Close()) | ||
} | ||
} | ||
|
||
if err == nil { | ||
pool.Put(poolConn) | ||
} | ||
return reply, err | ||
p.putBack(conn) | ||
|
||
return reply, nil | ||
} | ||
|
||
// Close implements the Upstream interface for *dnsOverTLS. | ||
// Close implements the [Upstream] interface for *dnsOverTLS. | ||
func (p *dnsOverTLS) Close() (err error) { | ||
p.poolMu.Lock() | ||
defer p.poolMu.Unlock() | ||
|
||
runtime.SetFinalizer(p, nil) | ||
|
||
if p.pool == nil { | ||
return nil | ||
p.connsMu.Lock() | ||
defer p.connsMu.Unlock() | ||
|
||
var closeErrs []error | ||
for _, conn := range p.conns { | ||
closeErr := conn.Close() | ||
if closeErr != nil && isCriticalTCP(closeErr) { | ||
closeErrs = append(closeErrs, closeErr) | ||
} | ||
} | ||
|
||
if len(closeErrs) > 0 { | ||
return errors.List("closing tls conns", closeErrs...) | ||
} | ||
|
||
return p.pool.Close() | ||
return nil | ||
} | ||
|
||
func (p *dnsOverTLS) exchangeConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, err error) { | ||
// conn returns the first available connection from the pool if there is any, or | ||
// dials a new one otherwise. | ||
func (p *dnsOverTLS) conn() (conn net.Conn, err error) { | ||
// Dial a new connection outside the lock, if needed. | ||
defer func() { | ||
if err == nil { | ||
return | ||
} | ||
|
||
if cerr := conn.Close(); cerr != nil { | ||
err = &errors.Pair{Returned: err, Deferred: cerr} | ||
if conn == nil { | ||
conn, err = p.dial() | ||
} | ||
}() | ||
|
||
p.connsMu.Lock() | ||
defer p.connsMu.Unlock() | ||
|
||
l := len(p.conns) | ||
if l == 0 { | ||
return nil, nil | ||
} | ||
|
||
p.conns, conn = p.conns[:l-1], p.conns[l-1] | ||
|
||
err = conn.SetDeadline(time.Now().Add(dialTimeout)) | ||
if err != nil { | ||
log.Debug("dot upstream: setting deadline to conn from pool: %s", err) | ||
|
||
// If deadLine can't be updated it means that connection was already | ||
// closed. | ||
return nil, nil | ||
} | ||
|
||
log.Debug("dot upstream: using existing conn %s", conn.RemoteAddr()) | ||
|
||
return conn, nil | ||
} | ||
|
||
func (p *dnsOverTLS) putBack(conn net.Conn) { | ||
p.connsMu.Lock() | ||
defer p.connsMu.Unlock() | ||
|
||
p.conns = append(p.conns, conn) | ||
} | ||
|
||
// exchangeWithConn tries to exchange the query using conn. | ||
func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, err error) { | ||
addr := p.Address() | ||
|
||
logBegin(addr, m) | ||
defer func() { logFinish(addr, err) }() | ||
|
||
dnsConn := dns.Conn{Conn: conn} | ||
|
||
err = dnsConn.WriteMsg(m) | ||
if err != nil { | ||
return nil, fmt.Errorf("sending request to %s: %w", p.Address(), err) | ||
return nil, fmt.Errorf("sending request to %s: %w", addr, err) | ||
} | ||
|
||
reply, err = dnsConn.ReadMsg() | ||
if err != nil { | ||
return nil, fmt.Errorf("reading response from %s: %w", p.Address(), err) | ||
return nil, fmt.Errorf("reading response from %s: %w", addr, err) | ||
} else if reply.Id != m.Id { | ||
err = dns.ErrId | ||
return reply, dns.ErrId | ||
} | ||
|
||
return reply, err | ||
} | ||
|
||
func (p *dnsOverTLS) getPool() (pool *TLSPool) { | ||
p.poolMu.Lock() | ||
defer p.poolMu.Unlock() | ||
// dial dials a new connection that may be stored in pool. | ||
func (p *dnsOverTLS) dial() (conn net.Conn, err error) { | ||
tlsConfig, dialContext, err := p.boot.get() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
conn, err = tlsDial(dialContext, "tcp", tlsConfig) | ||
if err != nil { | ||
return nil, fmt.Errorf("connecting to %s: %w", tlsConfig.ServerName, err) | ||
} | ||
|
||
return conn, nil | ||
} | ||
|
||
// tlsDial is basically the same as tls.DialWithDialer, but we will call our own | ||
// dialContext function to get connection. | ||
func tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) { | ||
// We're using bootstrapped address instead of what's passed to the | ||
// function. | ||
rawConn, err := dialContext(context.Background(), network, "") | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if p.pool == nil { | ||
p.pool = &TLSPool{boot: p.boot} | ||
// We want the timeout to cover the whole process: TCP connection and | ||
// TLS handshake dialTimeout will be used as connection deadLine. | ||
conn := tls.Client(rawConn, config) | ||
err = conn.SetDeadline(time.Now().Add(dialTimeout)) | ||
if err != nil { | ||
// Must not happen in normal circumstances. | ||
panic(fmt.Errorf("dnsproxy: tls dial: setting deadline: %w", err)) | ||
} | ||
|
||
return p.pool | ||
err = conn.Handshake() | ||
if err != nil { | ||
return nil, errors.WithDeferred(err, conn.Close()) | ||
} | ||
|
||
return conn, nil | ||
} | ||
|
||
// isCriticalTCP returns true if err isn't an expected error in terms of closing | ||
// the TCP connection. | ||
func isCriticalTCP(err error) (ok bool) { | ||
var netErr net.Error | ||
if errors.As(err, &netErr) && netErr.Timeout() { | ||
return false | ||
} | ||
|
||
switch { | ||
case | ||
errors.Is(err, io.EOF), | ||
errors.Is(err, net.ErrClosed), | ||
errors.Is(err, os.ErrDeadlineExceeded), | ||
isConnBroken(err): | ||
return false | ||
default: | ||
return true | ||
} | ||
} |
Oops, something went wrong.