Skip to content

Commit

Permalink
aghnet: imp host validation for system resolvers
Browse files Browse the repository at this point in the history
  • Loading branch information
ainar-g committed Apr 28, 2021
1 parent 5b80811 commit 8154797
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 13 deletions.
12 changes: 8 additions & 4 deletions internal/aghnet/systemresolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ type SystemResolvers interface {
}

const (
// fakeDialErr is an error which dialFunc is expected to return.
fakeDialErr agherr.Error = "this error signals the successful dialFunc work"
// errBadAddrPassed is returned when dialFunc can't parse an IP address.
errBadAddrPassed agherr.Error = "the passed string is not a valid IP address"

// badAddrPassedErr is returned when dialFunc can't parse an IP address.
badAddrPassedErr agherr.Error = "the passed string is not a valid IP address"
// errFakeDial is an error which dialFunc is expected to return.
errFakeDial agherr.Error = "this error signals the successful dialFunc work"

// errUnexpectedHostFormat is returned by validateDialedHost when the host has
// more than one percent sign.
errUnexpectedHostFormat agherr.Error = "unexpected host format"
)

// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
Expand Down
35 changes: 30 additions & 5 deletions internal/aghnet/systemresolvers_others.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"net"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -35,7 +36,7 @@ func (sr *systemResolvers) refresh() (err error) {

_, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc())
dnserr := &net.DNSError{}
if errors.As(err, &dnserr) && dnserr.Err == fakeDialErr.Error() {
if errors.As(err, &dnserr) && dnserr.Err == errFakeDial.Error() {
return nil
}

Expand All @@ -58,6 +59,29 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S
return s
}

// validateDialedHost validated the host used by resolvers in dialFunc.
func validateDialedHost(host string) (err error) {
defer agherr.Annotate("parsing %q: %w", &err, host)

var ipStr string
parts := strings.Split(host, "%")
switch len(parts) {
case 1:
ipStr = host
case 2:
// Remove the zone and check the IP address part.
ipStr = parts[0]
default:
return errUnexpectedHostFormat
}

if net.ParseIP(ipStr) == nil {
return errBadAddrPassed
}

return nil
}

// dialFunc gets the resolver's address and puts it into internal cache.
func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) {
// Just validate the passed address is a valid IP.
Expand All @@ -66,19 +90,20 @@ func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net
if err != nil {
// TODO(e.burkov): Maybe use a structured badAddrPassedErr to
// allow unwrapping of the real error.
return nil, fmt.Errorf("%s: %w", err, badAddrPassedErr)
return nil, fmt.Errorf("%s: %w", err, errBadAddrPassed)
}

if net.ParseIP(host) == nil {
return nil, fmt.Errorf("parsing %q: %w", host, badAddrPassedErr)
err = validateDialedHost(host)
if err != nil {
return nil, fmt.Errorf("validating dialed host: %w", err)
}

sr.addrsLock.Lock()
defer sr.addrsLock.Unlock()

sr.addrs.Add(host)

return nil, fakeDialErr
return nil, errFakeDial
}

func (sr *systemResolvers) Get() (rs []string) {
Expand Down
20 changes: 16 additions & 4 deletions internal/aghnet/systemresolvers_others_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,33 @@ func TestSystemResolvers_DialFunc(t *testing.T) {
imp := createTestSystemResolversImp(t, 0, nil)

testCases := []struct {
want error
name string
address string
want error
}{{
want: errFakeDial,
name: "valid",
address: "127.0.0.1",
want: fakeDialErr,
}, {
want: errFakeDial,
name: "valid_ipv6_port",
address: "[::1]:53",
}, {
want: errFakeDial,
name: "valid_ipv6_zone_port",
address: "[::1%lo0]:53",
}, {
want: errBadAddrPassed,
name: "invalid_split_host",
address: "127.0.0.1::123",
want: badAddrPassedErr,
}, {
want: errUnexpectedHostFormat,
name: "invalid_ipv6_zone_port",
address: "[::1%%lo0]:53",
}, {
want: errBadAddrPassed,
name: "invalid_parse_ip",
address: "not-ip",
want: badAddrPassedErr,
}}

for _, tc := range testCases {
Expand Down

0 comments on commit 8154797

Please sign in to comment.