Skip to content

Commit

Permalink
Pull request: fix ALPN
Browse files Browse the repository at this point in the history
Merge in DNS/dnsproxy from fix-alpn to master

Updates AdguardTeam/AdGuardHome#2681.

* commit '1beaef57054915a65636da4bda6157a1ec3d9acc':
  upstream: imp docs
  upstream: split host and port more carefully
  upstream: imp types, fix ALPN bug
  • Loading branch information
EugeneOne1 committed Mar 3, 2021
2 parents d3b053e + 1beaef5 commit 200e1aa
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 60 deletions.
50 changes: 18 additions & 32 deletions upstream/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"net"
"net/url"
"strings"
"sync"
"time"

Expand All @@ -30,7 +29,7 @@ var RootCAs *x509.CertPool
var CipherSuites []uint16

type bootstrapper struct {
address string // in form of "tls://one.one.one.one:853"
URL *url.URL
resolvers []*Resolver // list of Resolvers to use to resolve hostname, if necessary
dialContext dialHandler // specifies the dial function for creating unencrypted TCP connections.
resolvedConfig *tls.Config
Expand All @@ -46,11 +45,11 @@ type bootstrapper struct {
// newBootstrapperResolved creates a new bootstrapper that already contains resolved config.
// This can be done only in the case when we already know the resolver IP address.
// options -- Upstream customization options
func newBootstrapperResolved(address string, options Options) (*bootstrapper, error) {
func newBootstrapperResolved(upsURL *url.URL, options Options) (*bootstrapper, error) {
// get a host without port
host, port, err := getAddressHostPort(address)
host, port, err := net.SplitHostPort(upsURL.Host)
if err != nil {
return nil, fmt.Errorf("bootstrapper requires port in address %s", address)
return nil, fmt.Errorf("bootstrapper requires port in address %s", upsURL.String())
}

var resolverAddresses []string
Expand All @@ -60,7 +59,7 @@ func newBootstrapperResolved(address string, options Options) (*bootstrapper, er
}

b := &bootstrapper{
address: address,
URL: upsURL,
options: options,
}
b.dialContext = b.createDialContext(resolverAddresses)
Expand All @@ -72,7 +71,7 @@ func newBootstrapperResolved(address string, options Options) (*bootstrapper, er
// newBootstrapper initializes a new bootstrapper instance
// address -- original resolver address string (i.e. tls://one.one.one.one:853)
// options -- Upstream customization options
func newBootstrapper(address string, options Options) (*bootstrapper, error) {
func newBootstrapper(address *url.URL, options Options) (*bootstrapper, error) {
resolvers := []*Resolver{}
if len(options.Bootstrap) != 0 {
// Create a list of resolvers for parallel lookup
Expand All @@ -90,7 +89,7 @@ func newBootstrapper(address string, options Options) (*bootstrapper, error) {
}

return &bootstrapper{
address: address,
URL: address,
resolvers: resolvers,
options: options,
}, nil
Expand All @@ -113,11 +112,11 @@ func (n *bootstrapper) get() (*tls.Config, dialHandler, error) {
//

// get a host without port
host, port, err := getAddressHostPort(n.address)
addr := n.URL
host, port, err := net.SplitHostPort(addr.Host)
if err != nil {
addr := n.address
n.RUnlock()
return nil, nil, fmt.Errorf("bootstrapper requires port in address %s", addr)
return nil, nil, fmt.Errorf("bootstrapper requires port in address %s", addr.String())
}

// if n.address's host is an IP, just use it right away
Expand Down Expand Up @@ -191,8 +190,14 @@ func (n *bootstrapper) createTLSConfig(host string) *tls.Config {
VerifyPeerCertificate: n.options.VerifyServerCertificate,
}

tlsConfig.NextProtos = []string{
"http/1.1", http2.NextProtoTLS, NextProtoDQ,
// The supported application level protocols should be specified only
// for DNS-over-HTTPS and DNS-over-QUIC connections.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2681.
if n.URL.Scheme != "tls" {
tlsConfig.NextProtos = []string{
"http/1.1", http2.NextProtoTLS, NextProtoDQ,
}
}

return tlsConfig
Expand Down Expand Up @@ -230,22 +235,3 @@ func (n *bootstrapper) createDialContext(addresses []string) (dialContext dialHa
}
return
}

// getAddressHostPort splits resolver address into host and port
// returns host, port
func getAddressHostPort(address string) (string, string, error) {
justHostPort := address
if strings.Contains(address, "://") {
parsedURL, err := url.Parse(address)
if err != nil {
return "", "", errorx.Decorate(err, "failed to parse %s", address)
}

justHostPort = parsedURL.Host
}

// convert host to IP if necessary, we know that it's scheme://hostname:port/

// get a host without port
return net.SplitHostPort(justHostPort)
}
28 changes: 15 additions & 13 deletions upstream/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func AddressToUpstream(address string, options Options) (Upstream, error) {

// urlToBoot creates an instance of the bootstrapper with the specified options
// options -- Upstream customization options
func urlToBoot(resolverURL string, opts Options) (*bootstrapper, error) {
func urlToBoot(resolverURL *url.URL, opts Options) (*bootstrapper, error) {
if len(opts.ServerIPAddrs) == 0 {
return newBootstrapper(resolverURL, opts)
}
Expand All @@ -107,20 +107,23 @@ func urlToBoot(resolverURL string, opts Options) (*bootstrapper, error) {
func urlToUpstream(upstreamURL *url.URL, opts Options) (Upstream, error) {
switch upstreamURL.Scheme {
case "sdns":
return stampToUpstream(upstreamURL.String(), opts)
return stampToUpstream(upstreamURL, opts)

case "dns":
return &plainDNS{address: getHostWithPort(upstreamURL, "53"), timeout: opts.Timeout}, nil

case "tcp":
return &plainDNS{address: getHostWithPort(upstreamURL, "53"), timeout: opts.Timeout, preferTCP: true}, nil

case "quic":
if upstreamURL.Port() == "" {
// https://tools.ietf.org/html/draft-ietf-dprive-dnsoquic-00#section-8.2.1
// Early experiments MAY use port 784. This port is marked in the IANA
// registry as unassigned.
upstreamURL.Host += ":784"
}
resolverURL := upstreamURL.String()
b, err := urlToBoot(resolverURL, opts)

b, err := urlToBoot(upstreamURL, opts)
if err != nil {
return nil, errorx.Decorate(err, "couldn't create quic bootstrapper")
}
Expand All @@ -131,8 +134,8 @@ func urlToUpstream(upstreamURL *url.URL, opts Options) (Upstream, error) {
if upstreamURL.Port() == "" {
upstreamURL.Host += ":853"
}
resolverURL := upstreamURL.String()
b, err := urlToBoot(resolverURL, opts)

b, err := urlToBoot(upstreamURL, opts)
if err != nil {
return nil, errorx.Decorate(err, "couldn't create tls bootstrapper")
}
Expand All @@ -144,8 +147,7 @@ func urlToUpstream(upstreamURL *url.URL, opts Options) (Upstream, error) {
upstreamURL.Host += ":443"
}

resolverURL := upstreamURL.String()
b, err := urlToBoot(resolverURL, opts)
b, err := urlToBoot(upstreamURL, opts)
if err != nil {
return nil, errorx.Decorate(err, "couldn't create tls bootstrapper")
}
Expand All @@ -159,10 +161,10 @@ func urlToUpstream(upstreamURL *url.URL, opts Options) (Upstream, error) {

// stampToUpstream converts a DNS stamp to an Upstream
// options -- Upstream customization options
func stampToUpstream(address string, opts Options) (Upstream, error) {
stamp, err := dnsstamps.NewServerStampFromString(address)
func stampToUpstream(upsURL *url.URL, opts Options) (Upstream, error) {
stamp, err := dnsstamps.NewServerStampFromString(upsURL.String())
if err != nil {
return nil, errorx.Decorate(err, "failed to parse %s", address)
return nil, errorx.Decorate(err, "failed to parse %s", upsURL)
}

if stamp.ServerAddrStr != "" {
Expand All @@ -183,7 +185,7 @@ func stampToUpstream(address string, opts Options) (Upstream, error) {
case dnsstamps.StampProtoTypePlain:
return &plainDNS{address: stamp.ServerAddrStr, timeout: opts.Timeout}, nil
case dnsstamps.StampProtoTypeDNSCrypt:
b, err := newBootstrapper(address, opts)
b, err := newBootstrapper(upsURL, opts)
if err != nil {
return nil, fmt.Errorf("bootstrap server parse: %s", err)
}
Expand All @@ -196,7 +198,7 @@ func stampToUpstream(address string, opts Options) (Upstream, error) {
return AddressToUpstream(fmt.Sprintf("tls://%s", stamp.ProviderName), opts)
}

return nil, fmt.Errorf("unsupported protocol %v in %s", stamp.Proto, address)
return nil, fmt.Errorf("unsupported protocol %v in %s", stamp.Proto, upsURL)
}

// getHostWithPort is a helper function that appends port if needed
Expand Down
5 changes: 2 additions & 3 deletions upstream/upstream_dnscrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type dnsCrypt struct {
sync.RWMutex // protects DNSCrypt client
}

func (p *dnsCrypt) Address() string { return p.boot.address }
func (p *dnsCrypt) Address() string { return p.boot.URL.String() }

func (p *dnsCrypt) Exchange(m *dns.Msg) (*dns.Msg, error) {
reply, err := p.exchangeDNSCrypt(m)
Expand Down Expand Up @@ -60,8 +60,7 @@ func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (*dns.Msg, error) {

// Using "udp" for DNSCrypt upstreams by default
client = &dnscrypt.Client{Timeout: p.boot.options.Timeout}
ri, err := client.Dial(p.boot.address)

ri, err := client.Dial(p.Address())
if err != nil {
p.Unlock()
return nil, errorx.Decorate(err, "failed to fetch certificate info from %s", p.Address())
Expand Down
16 changes: 8 additions & 8 deletions upstream/upstream_doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type dnsOverHTTPS struct {
client *http.Client
}

func (p *dnsOverHTTPS) Address() string { return p.boot.address }
func (p *dnsOverHTTPS) Address() string { return p.boot.URL.String() }

func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) {
client, err := p.getClient()
Expand All @@ -57,10 +57,10 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient(m *dns.Msg, client *http.Client) (*dn

// It appears, that GET requests are more memory-efficient with Golang
// implementation of HTTP/2.
requestURL := p.boot.address + "?dns=" + base64.RawURLEncoding.EncodeToString(buf)
requestURL := p.Address() + "?dns=" + base64.RawURLEncoding.EncodeToString(buf)
req, err := http.NewRequest("GET", requestURL, nil)
if err != nil {
return nil, errorx.Decorate(err, "couldn't create a HTTP request to %s", p.boot.address)
return nil, errorx.Decorate(err, "couldn't create a HTTP request to %s", p.boot.URL)
}
req.Header.Set("Accept", "application/dns-message")

Expand All @@ -69,20 +69,20 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient(m *dns.Msg, client *http.Client) (*dn
defer resp.Body.Close()
}
if err != nil {
return nil, errorx.Decorate(err, "couldn't do a GET request to '%s'", p.boot.address)
return nil, errorx.Decorate(err, "couldn't do a GET request to '%s'", p.boot.URL)
}

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, errorx.Decorate(err, "couldn't read body contents for '%s'", p.boot.address)
return nil, errorx.Decorate(err, "couldn't read body contents for '%s'", p.boot.URL)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.boot.address)
return nil, fmt.Errorf("got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.boot.URL)
}
response := dns.Msg{}
err = response.Unpack(body)
if err != nil {
return nil, errorx.Decorate(err, "couldn't unpack DNS response from '%s': body is %s", p.boot.address, string(body))
return nil, errorx.Decorate(err, "couldn't unpack DNS response from '%s': body is %s", p.boot.URL, string(body))
}
if err == nil && response.Id != m.Id {
err = dns.ErrId
Expand Down Expand Up @@ -135,7 +135,7 @@ func (p *dnsOverHTTPS) createClient() (*http.Client, error) {
func (p *dnsOverHTTPS) createTransport() (*http.Transport, error) {
tlsConfig, dialContext, err := p.boot.get()
if err != nil {
return nil, errorx.Decorate(err, "couldn't bootstrap %s", p.boot.address)
return nil, errorx.Decorate(err, "couldn't bootstrap %s", p.boot.URL)
}

transport := &http.Transport{
Expand Down
2 changes: 1 addition & 1 deletion upstream/upstream_dot.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type dnsOverTLS struct {
sync.RWMutex // protects pool
}

func (p *dnsOverTLS) Address() string { return p.boot.address }
func (p *dnsOverTLS) Address() string { return p.boot.URL.String() }

func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
var pool *TLSPool
Expand Down
5 changes: 2 additions & 3 deletions upstream/upstream_quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type dnsOverQUIC struct {
sync.RWMutex // protects session and bytesPool
}

func (p *dnsOverQUIC) Address() string { return p.boot.address }
func (p *dnsOverQUIC) Address() string { return p.boot.URL.String() }

func (p *dnsOverQUIC) Exchange(m *dns.Msg) (*dns.Msg, error) {
session, err := p.getSession(true)
Expand Down Expand Up @@ -55,8 +55,7 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (*dns.Msg, error) {
_ = stream.Close()

pool := p.getBytesPool()
var respBuf []byte
respBuf = pool.Get().([]byte)
respBuf := pool.Get().([]byte)

// Linter says that the argument needs to be pointer-like
// But it's already pointer-like
Expand Down

0 comments on commit 200e1aa

Please sign in to comment.