Skip to content

Commit

Permalink
Pull request 312: 6321 boot ttl vol.1
Browse files Browse the repository at this point in the history
Updates AdguardTeam/AdGuardHome#6321.

Squashed commit of the following:

commit fb8f98b
Merge: 329e3e3 a87a3df
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Jan 11 14:14:06 2024 +0300

    Merge branch 'master' into 6321-boot-ttl-vol.1

commit 329e3e3
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Jan 9 23:27:41 2024 +0300

    all: imp docs

commit 91cc3f0
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Fri Dec 29 17:59:36 2023 +0500

    all: imp code, rm redundant changes

commit f0df8c2
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Dec 28 19:20:07 2023 +0500

    all: imp code

commit c1fd087
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Dec 27 18:23:38 2023 +0300

    all: move code, use new types

commit e1d9405
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Dec 27 18:17:08 2023 +0300

    upstream: add separate test
  • Loading branch information
EugeneOne1 committed Jan 11, 2024
1 parent a87a3df commit 04571e6
Show file tree
Hide file tree
Showing 9 changed files with 342 additions and 161 deletions.
75 changes: 46 additions & 29 deletions internal/bootstrap/bootstrap.go
Expand Up @@ -14,20 +14,42 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"golang.org/x/exp/slices"
)

// Network is a network type for use in [Resolver]'s methods.
type Network = string

const (
// NetworkIP is a network type for both address families.
NetworkIP Network = "ip"

// NetworkIP4 is a network type for IPv4 address family.
NetworkIP4 Network = "ip4"

// NetworkIP6 is a network type for IPv6 address family.
NetworkIP6 Network = "ip6"

// NetworkTCP is a network type for TCP connections.
NetworkTCP Network = "tcp"

// NetworkUDP is a network type for UDP connections.
NetworkUDP Network = "udp"
)

// DialHandler is a dial function for creating unencrypted network connections
// to the upstream server. It establishes the connection to the server
// specified at initialization and ignores the addr.
type DialHandler func(ctx context.Context, network, addr string) (conn net.Conn, err error)
// specified at initialization and ignores the addr. network must be one of
// [NetworkTCP] or [NetworkUDP].
type DialHandler func(ctx context.Context, network Network, addr string) (conn net.Conn, err error)

// ResolveDialContext returns a DialHandler that uses addresses resolved from u
// using resolver. u must not be nil.
func ResolveDialContext(
u *url.URL,
timeout time.Duration,
resolver Resolver,
preferIPv6 bool,
r Resolver,
preferV6 bool,
) (h DialHandler, err error) {
defer func() { err = errors.Annotate(err, "dialing %q: %w", u.Host) }()

Expand All @@ -38,7 +60,7 @@ func ResolveDialContext(
return nil, err
}

if resolver == nil {
if r == nil {
return nil, fmt.Errorf("resolver is nil: %w", ErrNoResolvers)
}

Expand All @@ -49,36 +71,28 @@ func ResolveDialContext(
defer cancel()
}

ips, err := resolver.LookupNetIP(ctx, "ip", host)
ips, err := r.LookupNetIP(ctx, NetworkIP, host)
if err != nil {
return nil, fmt.Errorf("resolving hostname: %w", err)
}

proxynetutil.SortNetIPAddrs(ips, preferIPv6)
if preferV6 {
slices.SortStableFunc(ips, proxynetutil.PreferIPv6)
} else {
slices.SortStableFunc(ips, proxynetutil.PreferIPv4)
}

addrs := make([]string, 0, len(ips))
for _, ip := range ips {
if !ip.IsValid() {
// All invalid addresses should be in the tail after sorting.
break
}

addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)).String())
addrs = append(addrs, netip.AddrPortFrom(ip, port).String())
}

return NewDialContext(timeout, addrs...), nil
}

// NewDialContext returns a DialHandler that dials addrs and returns the first
// successful connection. At least a single addr should be specified.
//
// TODO(e.burkov): Consider using [Resolver] instead of
// [upstream.Options.Bootstrap] and [upstream.Options.ServerIPAddrs].
func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) {
dialer := &net.Dialer{
Timeout: timeout,
}

l := len(addrs)
if l == 0 {
log.Debug("bootstrap: no addresses to dial")
Expand All @@ -88,9 +102,11 @@ func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) {
}
}

// TODO(e.burkov): Check IPv6 preference here.
dialer := &net.Dialer{
Timeout: timeout,
}

return func(ctx context.Context, network, _ string) (conn net.Conn, err error) {
return func(ctx context.Context, network Network, _ string) (conn net.Conn, err error) {
var errs []error

// Return first succeeded connection. Note that we're using addrs
Expand All @@ -101,17 +117,18 @@ func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) {
start := time.Now()
conn, err = dialer.DialContext(ctx, network, addr)
elapsed := time.Since(start)
if err == nil {
log.Debug("bootstrap: connection to %s succeeded in %s", addr, elapsed)
if err != nil {
log.Debug("bootstrap: connection to %s failed in %s: %s", addr, elapsed, err)
errs = append(errs, err)

return conn, nil
continue
}

log.Debug("bootstrap: connection to %s failed in %s: %s", addr, elapsed, err)
errs = append(errs, err)
log.Debug("bootstrap: connection to %s succeeded in %s", addr, elapsed)

return conn, nil
}

// TODO(e.burkov): Use errors.Join in Go 1.20.
return nil, errors.List("all dialers failed", errs...)
return nil, errors.Join(errs...)
}
}
8 changes: 4 additions & 4 deletions internal/bootstrap/bootstrap_test.go
Expand Up @@ -87,7 +87,7 @@ func TestResolveDialContext(t *testing.T) {
network string,
host string,
) (addrs []netip.Addr, err error) {
require.Equal(pt, "ip", network)
require.Equal(pt, bootstrap.NetworkIP, network)
require.Equal(pt, hostname, host)

return tc.addresses, nil
Expand All @@ -103,7 +103,7 @@ func TestResolveDialContext(t *testing.T) {
)
require.NoError(t, err)

conn, err := dialContext(context.Background(), "tcp", "")
conn, err := dialContext(context.Background(), bootstrap.NetworkTCP, "")
require.NoError(t, err)

expected, ok := testutil.RequireReceive(t, sig, testTimeout)
Expand All @@ -120,7 +120,7 @@ func TestResolveDialContext(t *testing.T) {
network string,
host string,
) (addrs []netip.Addr, err error) {
require.Equal(pt, "ip", network)
require.Equal(pt, bootstrap.NetworkIP, network)
require.Equal(pt, hostname, host)

return nil, nil
Expand All @@ -135,7 +135,7 @@ func TestResolveDialContext(t *testing.T) {
)
require.NoError(t, err)

_, err = dialContext(context.Background(), "tcp", "")
_, err = dialContext(context.Background(), bootstrap.NetworkTCP, "")
testutil.AssertErrorMsg(t, "no addresses", err)
})

Expand Down
6 changes: 6 additions & 0 deletions internal/bootstrap/error.go
@@ -0,0 +1,6 @@
package bootstrap

import "github.com/AdguardTeam/golibs/errors"

// ErrNoResolvers is returned when zero resolvers specified.
const ErrNoResolvers errors.Error = "no resolvers specified"
66 changes: 56 additions & 10 deletions internal/bootstrap/resolver.go
Expand Up @@ -8,22 +8,21 @@ import (

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/exp/slices"
)

// Resolver resolves the hostnames to IP addresses.
// Resolver resolves the hostnames to IP addresses. Note, that [net.Resolver]
// from standard library also implements this interface.
type Resolver interface {
// LookupNetIP looks up the IP addresses for the given host. network must
// be one of "ip", "ip4" or "ip6". The response may be empty even if err is
// nil.
LookupNetIP(ctx context.Context, network, host string) (addrs []netip.Addr, err error)
// LookupNetIP looks up the IP addresses for the given host. network should
// be one of [NetworkIP], [NetworkIP4] or [NetworkIP6]. The response may be
// empty even if err is nil. All the addrs must be valid.
LookupNetIP(ctx context.Context, network Network, host string) (addrs []netip.Addr, err error)
}

// type check
var _ Resolver = &net.Resolver{}

// ErrNoResolvers is returned when zero resolvers specified.
const ErrNoResolvers errors.Error = "no resolvers specified"

// ParallelResolver is a slice of resolvers that are queried concurrently. The
// first successful response is returned.
type ParallelResolver []Resolver
Expand All @@ -34,7 +33,7 @@ var _ Resolver = ParallelResolver(nil)
// LookupNetIP implements the [Resolver] interface for ParallelResolver.
func (r ParallelResolver) LookupNetIP(
ctx context.Context,
network string,
network Network,
host string,
) (addrs []netip.Addr, err error) {
resolversNum := len(r)
Expand All @@ -48,7 +47,7 @@ func (r ParallelResolver) LookupNetIP(
}

// Size of channel must accommodate results of lookups from all resolvers,
// sending into channel will be block otherwise.
// sending into channel will block otherwise.
ch := make(chan any, resolversNum)
for _, rslv := range r {
go lookupAsync(ctx, rslv, network, host, ch)
Expand Down Expand Up @@ -97,3 +96,50 @@ func lookup(ctx context.Context, r Resolver, network, host string) (addrs []neti

return addrs, err
}

// ConsequentResolver is a slice of resolvers that are queried in order until
// the first successful non-empty response, as opposed to just successful
// response requirement in [ParallelResolver].
type ConsequentResolver []Resolver

// type check
var _ Resolver = ConsequentResolver(nil)

// LookupNetIP implements the [Resolver] interface for ConsequentResolver.
func (resolvers ConsequentResolver) LookupNetIP(
ctx context.Context,
network Network,
host string,
) (addrs []netip.Addr, err error) {
if len(resolvers) == 0 {
return nil, ErrNoResolvers
}

var errs []error
for _, r := range resolvers {
addrs, err = r.LookupNetIP(ctx, network, host)
if err == nil && len(addrs) > 0 {
return addrs, nil
}

errs = append(errs, err)
}

return nil, errors.Join(errs...)
}

// StaticResolver is a resolver which always responds with an underlying slice
// of IP addresses regardless of host and network.
type StaticResolver []netip.Addr

// type check
var _ Resolver = StaticResolver(nil)

// LookupNetIP implements the [Resolver] interface for StaticResolver.
func (r StaticResolver) LookupNetIP(
_ context.Context,
_ Network,
_ string,
) (addrs []netip.Addr, err error) {
return slices.Clone(r), nil
}
36 changes: 36 additions & 0 deletions internal/netutil/netutil.go
Expand Up @@ -12,6 +12,42 @@ import (
"golang.org/x/exp/slices"
)

// PreferIPv4 compares two addresses, preferring IPv4 addresses over IPv6 ones.
// Invalid addresses are sorted near the end.
func PreferIPv4(a, b netip.Addr) (res int) {
if !a.IsValid() {
return 1
} else if !b.IsValid() {
return -1
}

if aIs4 := a.Is4(); aIs4 == b.Is4() {
return a.Compare(b)
} else if aIs4 {
return -1
}

return 1
}

// PreferIPv6 compares two addresses, preferring IPv6 addresses over IPv4 ones.
// Invalid addresses are sorted near the end.
func PreferIPv6(a, b netip.Addr) (res int) {
if !a.IsValid() {
return 1
} else if !b.IsValid() {
return -1
}

if aIs6 := a.Is6(); aIs6 == b.Is6() {
return a.Compare(b)
} else if aIs6 {
return -1
}

return 1
}

// SortNetIPAddrs sorts addrs in accordance with the protocol preferences.
// Invalid addresses are sorted near the end. Zones are ignored.
func SortNetIPAddrs(addrs []netip.Addr, preferIPv6 bool) {
Expand Down
4 changes: 2 additions & 2 deletions proxy/proxy_test.go
Expand Up @@ -329,7 +329,7 @@ func TestExchangeWithReservedDomains(t *testing.T) {
upstreams,
&upstream.Options{
InsecureSkipVerify: false,
Bootstrap: googleRslv,
Bootstrap: upstream.NewCachingResolver(googleRslv),
Timeout: 1 * time.Second,
},
)
Expand Down Expand Up @@ -412,7 +412,7 @@ func TestOneByOneUpstreamsExchange(t *testing.T) {
u, err = upstream.AddressToUpstream(
line,
&upstream.Options{
Bootstrap: googleRslv,
Bootstrap: upstream.NewCachingResolver(googleRslv),
Timeout: timeOut,
},
)
Expand Down

0 comments on commit 04571e6

Please sign in to comment.