Skip to content

Commit

Permalink
dnsforward: move code
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Apr 9, 2024
1 parent 08bb7d4 commit c834067
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 60 deletions.
60 changes: 60 additions & 0 deletions internal/dnsforward/beforerequest.go
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
Expand Down Expand Up @@ -51,3 +52,62 @@ func (s *Server) HandleBefore(

return nil
}

// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
clientID, err = clientIDFromDNSContextHTTPS(pctx)
if err != nil {
return "", fmt.Errorf("checking url: %w", err)
} else if clientID != "" {
return clientID, nil
}

// Go on and check the domain name as well.
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return "", nil
}

hostSrvName := s.conf.ServerName
if hostSrvName == "" {
return "", nil
}

cliSrvName, err := clientServerName(pctx, proto)
if err != nil {
return "", err
}

clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil {
return "", fmt.Errorf("clientid check: %w", err)
}

return clientID, nil
}

// errAccessBlocked is a sentinel error returned when a request is blocked by
// access settings.
var errAccessBlocked errors.Error = "blocked by access settings"

// preBlockedResponse returns a protocol-appropriate response for a request that
// was blocked by access settings.
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (err error) {
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
// Return nil so that dnsproxy drops the connection and thus
// prevent DNS amplification attacks.
return errAccessBlocked
}

return &proxy.BeforeRequestError{
Err: errAccessBlocked,
Response: s.makeResponseREFUSED(pctx.Req),
}
}
40 changes: 0 additions & 40 deletions internal/dnsforward/clientid.go
Expand Up @@ -110,46 +110,6 @@ type quicConnection interface {
ConnectionState() (cs quic.ConnectionState)
}

// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
clientID, err = clientIDFromDNSContextHTTPS(pctx)
if err != nil {
return "", fmt.Errorf("checking url: %w", err)
} else if clientID != "" {
return clientID, nil
}

// Go on and check the domain name as well.
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return "", nil
}

hostSrvName := s.conf.ServerName
if hostSrvName == "" {
return "", nil
}

cliSrvName, err := clientServerName(pctx, proto)
if err != nil {
return "", err
}

clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil {
return "", fmt.Errorf("clientid check: %w", err)
}

return clientID, nil
}

// clientServerName returns the TLS server name based on the protocol. For
// DNS-over-HTTPS requests, it will return the hostname part of the Host header
// if there is one.
Expand Down
20 changes: 0 additions & 20 deletions internal/dnsforward/msg.go
Expand Up @@ -6,7 +6,6 @@ import (

"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
Expand Down Expand Up @@ -340,25 +339,6 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
return resp
}

// errAccessBlocked is a sentinel error returned when a request is blocked by
// access settings.
var errAccessBlocked errors.Error = "blocked by access settings"

// preBlockedResponse returns a protocol-appropriate response for a request that
// was blocked by access settings.
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (err error) {
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
// Return nil so that dnsproxy drops the connection and thus
// prevent DNS amplification attacks.
return errAccessBlocked
}

return &proxy.BeforeRequestError{
Err: errAccessBlocked,
Response: s.makeResponseREFUSED(pctx.Req),
}
}

// Create REFUSED DNS response
func (s *Server) makeResponseREFUSED(req *dns.Msg) *dns.Msg {
return s.reply(req, dns.RcodeRefused)
Expand Down

0 comments on commit c834067

Please sign in to comment.