Skip to content

Commit

Permalink
dnsforward: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Apr 9, 2024
1 parent b27547e commit 08bb7d4
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 99 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Expand Up @@ -27,7 +27,14 @@ NOTE: Add new changes BELOW THIS COMMENT.

- Support for comments in the ipset file ([#5345]).

### Fixed

- Subdomains of `in-addr.arpa` and `ip6.arpa` containing zero-length prefix
incorrectly considered invalid when specified for private RDNS upstream
servers ([#6854]).

[#5345]: https://github.com/AdguardTeam/AdGuardHome/issues/5345
[#6854]: https://github.com/AdguardTeam/AdGuardHome/issues/6854

<!--
NOTE: Add new changes ABOVE THIS COMMENT.
Expand Down
53 changes: 53 additions & 0 deletions internal/dnsforward/beforerequest.go
@@ -0,0 +1,53 @@
package dnsforward

import (
"encoding/binary"
"fmt"

"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)

// type check
var _ proxy.BeforeRequestHandler = (*Server)(nil)

// HandleBefore is the handler that is called before any other processing,
// including logs. It performs access checks and puts the client ID, if there
// is one, into the server's cache.
//
// TODO(e.burkov): Write tests.
func (s *Server) HandleBefore(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (err error) {
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return fmt.Errorf("getting clientid: %w", err)
}

blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}

if len(pctx.Req.Question) == 1 {
q := pctx.Req.Question[0]
qt := q.Qtype
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)

return s.preBlockedResponse(pctx)
}
}

if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}

return nil
}
43 changes: 0 additions & 43 deletions internal/dnsforward/filter.go
@@ -1,60 +1,17 @@
package dnsforward

import (
"encoding/binary"
"fmt"
"net"
"slices"
"strings"

"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
)

// type check
var _ proxy.BeforeRequestHandler = (*Server)(nil)

// HandleBefore is the handler that is called before any other processing,
// including logs. It performs access checks and puts the client ID, if there
// is one, into the server's cache.
func (s *Server) HandleBefore(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (err error) {
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return fmt.Errorf("getting clientid: %w", err)
}

blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}

if len(pctx.Req.Question) == 1 {
q := pctx.Req.Question[0]
qt := q.Qtype
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)

return s.preBlockedResponse(pctx)
}
}

if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}

return nil
}

// clientRequestFilteringSettings looks up client filtering settings using the
// client's IP address and ID, if any, from dctx.
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
Expand Down
113 changes: 59 additions & 54 deletions internal/dnsforward/http.go
Expand Up @@ -261,48 +261,6 @@ func (req *jsonDNSConfig) checkUpstreamMode() (err error) {
}
}

// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
}

var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()

for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}

var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}

if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}

return nil
}

// checkFallbacks returns an error if any fallback address is invalid.
func (req *jsonDNSConfig) checkFallbacks() (err error) {
if req.Fallbacks == nil {
return nil
}

_, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{})
if err != nil {
return fmt.Errorf("fallback servers: %w", err)
}

return nil
}

// validate returns an error if any field of req is invalid.
//
// TODO(s.chzhen): Parse, don't validate.
Expand Down Expand Up @@ -342,23 +300,68 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
return nil
}

// checkUpstreams returns an error if lines can't be parsed as an upstream
// configuration. If privateNets is not nil, it also checks that the domain
// specifications are strictly ARPA domains containing the prefixes within the
// set.
func checkUpstreams(lines []string, section string, privateNets netutil.SubnetSet) (err error) {
defer func() { err = errors.Annotate(err, "%s servers: %w", section) }()

uc, err := proxy.ParseUpstreamsConfig(lines, &upstream.Options{})
if err == nil {
defer func() { err = errors.WithDeferred(err, uc.Close()) }()

if privateNets != nil {
err = proxy.ValidatePrivateConfig(uc, privateNets)
}
}

return err
}

// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
}

var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()

for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}

var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}

if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}

return nil
}

// validateUpstreamDNSServers returns an error if any field of req is invalid.
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
if req.Upstreams != nil {
_, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{})
err = checkUpstreams(*req.Upstreams, "upstream", nil)
if err != nil {
return fmt.Errorf("upstream servers: %w", err)
// Don't wrap the error since it's informative enough as is.
return err
}
}

if req.LocalPTRUpstreams != nil {
var uc *proxy.UpstreamConfig
uc, err = proxy.ParseUpstreamsConfig(*req.LocalPTRUpstreams, &upstream.Options{})
if err == nil {
err = proxy.ValidatePrivateConfig(uc, privateNets)
}
err = checkUpstreams(*req.LocalPTRUpstreams, "private upstream", privateNets)
if err != nil {
return fmt.Errorf("private upstream servers: %w", err)
// Don't wrap the error since it's informative enough as is.
return err
}
}

Expand All @@ -368,10 +371,12 @@ func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetS
return err
}

err = req.checkFallbacks()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
if req.Fallbacks != nil {
err = checkUpstreams(*req.Fallbacks, "fallback", nil)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
}

return nil
Expand Down
4 changes: 2 additions & 2 deletions internal/dnsforward/msg.go
Expand Up @@ -12,8 +12,8 @@ import (
"github.com/miekg/dns"
)

// TODO(e.burkov): Call all the other methods by a [proxy.MessageConstructor]
// template.
// TODO(e.burkov): Name all the methods by a [proxy.MessageConstructor]
// template. Also extract all the methods to a separate entity.

// reply creates a DNS response for req.
func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) {
Expand Down

0 comments on commit 08bb7d4

Please sign in to comment.