Skip to content

Commit

Permalink
dnsforward: fix orig resp
Browse files Browse the repository at this point in the history
  • Loading branch information
ainar-g committed Sep 8, 2023
1 parent 5f696da commit 3534f66
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 42 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ NOTE: Add new changes BELOW THIS COMMENT.

### Fixed

- Incorrect original answer when a response is filtered ([#6183]).
- Empty or default Safe Browsing and Parental Control settings ([#6181]).
- Various UI issues.

[#6183]: https://github.com/AdguardTeam/AdGuardHome/issues/6183
[#6181]: https://github.com/AdguardTeam/AdGuardHome/issues/6181

<!--
Expand Down
47 changes: 25 additions & 22 deletions internal/dnsforward/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
Expand Down Expand Up @@ -140,36 +141,37 @@ func (s *Server) filterRewritten(

// checkHostRules checks the host against filters. It is safe for concurrent
// use.
func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) (
r *filtering.Result,
err error,
) {
func (s *Server) checkHostRules(
host string,
rrtype rules.RRType,
setts *filtering.Settings,
) (r *filtering.Result, err error) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()

var res filtering.Result
res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
res, err := s.dnsFilter.CheckHostRules(host, rrtype, setts)
if err != nil {
return nil, err
}

return &res, err
}

// filterDNSResponse checks each resource record of the response's answer
// section from pctx and returns a non-nil res if at least one of canonical
// names or IP addresses in it matches the filtering rules.
func (s *Server) filterDNSResponse(
pctx *proxy.DNSContext,
setts *filtering.Settings,
) (res *filtering.Result, err error) {
// filterDNSResponse checks each resource record of answer section of
// dctx.proxyCtx.Res. It sets dctx.result and dctx.origResp if at least one of
// canonical names, IP addresses, or HTTPS RR hints in it matches the filtering
// rules, as well as sets dctx.proxyCtx.Res to the filtered response.
func (s *Server) filterDNSResponse(dctx *dnsContext) (err error) {
setts := dctx.setts
if !setts.FilteringEnabled {
return nil, nil
return nil
}

for _, a := range pctx.Res.Answer {
var res *filtering.Result
pctx := dctx.proxyCtx
for i, a := range pctx.Res.Answer {
host := ""
var rrtype uint16
var rrtype rules.RRType
switch a := a.(type) {
case *dns.CNAME:
host = strings.TrimSuffix(a.Target, ".")
Expand All @@ -195,18 +197,19 @@ func (s *Server) filterDNSResponse(
log.Debug("dnsforward: checked %s %s for %s", dns.Type(rrtype), host, a.Header().Name)

if err != nil {
return nil, err
} else if res == nil {
continue
} else if res.IsFiltered {
return fmt.Errorf("filtering answer at index %d: %w", i, err)
} else if res != nil && res.IsFiltered {
dctx.result = res
dctx.origResp = pctx.Res
pctx.Res = s.genDNSFilterMessage(pctx, res)

log.Debug("dnsforward: matched %q by response: %q", pctx.Req.Question[0].Name, host)

return res, nil
break
}
}

return nil, nil
return nil
}

// removeIPv6Hints deletes IPv6 hints from RR values.
Expand Down
22 changes: 15 additions & 7 deletions internal/dnsforward/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,26 +328,34 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
}

res, rErr := s.filterDNSResponse(pctx, &filtering.Settings{
ProtectionEnabled: true,
FilteringEnabled: true,
})
require.NoError(t, rErr)
dctx := &dnsContext{
proxyCtx: pctx,
setts: &filtering.Settings{
ProtectionEnabled: true,
FilteringEnabled: true,
},
}

fltErr := s.filterDNSResponse(dctx)
require.NoError(t, fltErr)

res := dctx.result
if tc.wantRule == "" {
assert.Nil(t, res)

return
}

want := &filtering.Result{
wantResult := &filtering.Result{
IsFiltered: true,
Reason: filtering.FilteredBlockList,
Rules: []*filtering.ResultRule{{
Text: tc.wantRule,
}},
}
assert.Equal(t, want, res)

assert.Equal(t, wantResult, res)
assert.Equal(t, resp, dctx.origResp)
})
}
}
Expand Down
21 changes: 8 additions & 13 deletions internal/dnsforward/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,11 +671,11 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
}

// Apply filtering logic
func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing filtering before req")
defer log.Debug("dnsforward: finished processing filtering before req")

if ctx.proxyCtx.Res != nil {
if dctx.proxyCtx.Res != nil {
// Go on since the response is already set.
return resultCodeSuccess
}
Expand All @@ -684,8 +684,8 @@ func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode)
defer s.serverLock.RUnlock()

var err error
if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
ctx.err = err
if dctx.result, err = s.filterDNSRequest(dctx); err != nil {
dctx.err = err

return resultCodeError
}
Expand Down Expand Up @@ -857,7 +857,6 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
log.Debug("dnsforward: started processing filtering after resp")
defer log.Debug("dnsforward: finished processing filtering after resp")

pctx := dctx.proxyCtx
switch res := dctx.result; res.Reason {
case filtering.NotFilteredAllowList:
return resultCodeSuccess
Expand All @@ -871,6 +870,7 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
return resultCodeSuccess
}

pctx := dctx.proxyCtx
pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion
if len(pctx.Res.Answer) > 0 {
rr := s.genAnswerCNAME(pctx.Req, res.CanonName)
Expand All @@ -880,31 +880,26 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)

return resultCodeSuccess
default:
return s.filterAfterResponse(dctx, pctx)
return s.filterAfterResponse(dctx)
}
}

// filterAfterResponse returns the result of filtering the response that wasn't
// explicitly allowed or rewritten.
func (s *Server) filterAfterResponse(dctx *dnsContext, pctx *proxy.DNSContext) (res resultCode) {
func (s *Server) filterAfterResponse(dctx *dnsContext) (res resultCode) {
// Check the response only if it's from an upstream. Don't check the
// response if the protection is disabled since dnsrewrite rules aren't
// applied to it anyway.
if !dctx.protectionEnabled || !dctx.responseFromUpstream {
return resultCodeSuccess
}

result, err := s.filterDNSResponse(pctx, dctx.setts)
err := s.filterDNSResponse(dctx)
if err != nil {
dctx.err = err

return resultCodeError
}

if result != nil {
dctx.result = result
dctx.origResp = pctx.Res
}

return resultCodeSuccess
}

0 comments on commit 3534f66

Please sign in to comment.