Skip to content

Commit

Permalink
feat: 🎸 changed the way upstreams are cusomized per-request
Browse files Browse the repository at this point in the history
This is necessary to properly implement this:
AdguardTeam/AdGuardHome#1539
  • Loading branch information
ameshkov committed May 13, 2020
1 parent b8a91c7 commit d8f4a77
Show file tree
Hide file tree
Showing 12 changed files with 394 additions and 357 deletions.
21 changes: 10 additions & 11 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,16 @@ func createProxyConfig(options Options) proxy.Config {

// Create the config
config := proxy.Config{
Upstreams: upstreamConfig.Upstreams,
DomainsReservedUpstreams: upstreamConfig.DomainReservedUpstreams,
Ratelimit: options.Ratelimit,
CacheEnabled: options.Cache,
CacheSizeBytes: options.CacheSizeBytes,
CacheMinTTL: options.CacheMinTTL,
CacheMaxTTL: options.CacheMaxTTL,
RefuseAny: options.RefuseAny,
AllServers: options.AllServers,
EnableEDNSClientSubnet: options.EnableEDNSSubnet,
FindFastestAddr: options.FastestAddress,
UpstreamConfig: &upstreamConfig,
Ratelimit: options.Ratelimit,
CacheEnabled: options.Cache,
CacheSizeBytes: options.CacheSizeBytes,
CacheMinTTL: options.CacheMinTTL,
CacheMaxTTL: options.CacheMaxTTL,
RefuseAny: options.RefuseAny,
AllServers: options.AllServers,
EnableEDNSClientSubnet: options.EnableEDNSSubnet,
FindFastestAddr: options.FastestAddress,
}

if options.EDNSAddr != "" {
Expand Down
2 changes: 1 addition & 1 deletion proxy/bogus_nxdomain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestBogusNXDomainTypeA(t *testing.T) {
dnsProxy.BogusNXDomain = []net.IP{net.ParseIP("4.3.2.1")}

u := testUpstream{}
dnsProxy.Upstreams = []upstream.Upstream{&u}
dnsProxy.UpstreamConfig.Upstreams = []upstream.Upstream{&u}
err := dnsProxy.Start()
assert.Nil(t, err)

Expand Down
2 changes: 1 addition & 1 deletion proxy/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func TestCacheExpirationWithTTLOverride(t *testing.T) {
dnsProxy.CacheMinTTL = 20
dnsProxy.CacheMaxTTL = 40
u := testUpstream{}
dnsProxy.Upstreams = []upstream.Upstream{&u}
dnsProxy.UpstreamConfig.Upstreams = []upstream.Upstream{&u}
err := dnsProxy.Start()
if err != nil {
t.Fatalf("cannot start the DNS proxy: %s", err)
Expand Down
113 changes: 11 additions & 102 deletions proxy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,10 @@ package proxy
import (
"crypto/tls"
"errors"
"fmt"
"net"
"strings"
"time"

"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils"

"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
)

// Config contains all the fields necessary for proxy configuration
Expand All @@ -35,11 +30,10 @@ type Config struct {
// Upstream DNS servers and their settings
// --

Upstreams []upstream.Upstream // list of upstreams
Fallbacks []upstream.Upstream // list of fallback resolvers (which will be used if regular upstream failed to answer)
AllServers bool // if true, parallel queries to all configured upstream servers are enabled
DomainsReservedUpstreams map[string][]upstream.Upstream // map of domains and lists of corresponding upstreams
FindFastestAddr bool // use Fastest Address algorithm
UpstreamConfig *UpstreamConfig // Upstream DNS servers configuration
Fallbacks []upstream.Upstream // list of fallback resolvers (which will be used if regular upstream failed to answer)
AllServers bool // if true, parallel queries to all configured upstream servers are enabled
FindFastestAddr bool // use Fastest Address algorithm

// BogusNXDomain - transforms responses that contain only given IP addresses into NXDOMAIN
// Similar to dnsmasq's "bogus-nxdomain"
Expand Down Expand Up @@ -88,95 +82,6 @@ type Config struct {
MaxGoroutines int // maximum number of goroutines processing the DNS requests (important for mobile)
}

// UpstreamConfig is a wrapper for list of default upstreams and map of reserved domains and corresponding upstreams
type UpstreamConfig struct {
Upstreams []upstream.Upstream // list of default upstreams
DomainReservedUpstreams map[string][]upstream.Upstream // map of reserved domains and lists of corresponding upstreams
}

// ParseUpstreamsConfig returns UpstreamConfig and error if upstreams configuration is invalid
// default upstream syntax: <upstreamString>
// reserved upstream syntax: [/domain1/../domainN/]<upstreamString>
// More specific domains take priority over less specific domains,
// To exclude more specific domains from reserved upstreams querying you should use the following syntax: [/domain1/../domainN/]#
// So the following config: ["[/host.com/]1.2.3.4", "[/www.host.com/]2.3.4.5", "[/maps.host.com/]#", "3.4.5.6"]
// will send queries for *.host.com to 1.2.3.4, except for *.www.host.com, which will go to 2.3.4.5 and *.maps.host.com,
// which will go to default server 3.4.5.6 with all other domains
func ParseUpstreamsConfig(upstreamConfig, bootstrapDNS []string, timeout time.Duration) (UpstreamConfig, error) {
return ParseUpstreamsConfigEx(upstreamConfig, bootstrapDNS, timeout, func(address string, opts upstream.Options) (upstream.Upstream, error) {
return upstream.AddressToUpstream(address, opts)
})
}

// AddressToUpstreamFunction is a type for a callback function which creates an upstream object
type AddressToUpstreamFunction func(address string, opts upstream.Options) (upstream.Upstream, error)

// ParseUpstreamsConfigEx is an extended version of ParseUpstreamsConfig() which has a custom callback function which creates an upstream object
func ParseUpstreamsConfigEx(upstreamConfig, bootstrapDNS []string, timeout time.Duration, addressToUpstreamFunction AddressToUpstreamFunction) (UpstreamConfig, error) {
upstreams := []upstream.Upstream{}
domainReservedUpstreams := map[string][]upstream.Upstream{}

if len(bootstrapDNS) > 0 {
for i, b := range bootstrapDNS {
log.Info("Bootstrap %d: %s", i, b)
}
}

for i, u := range upstreamConfig {
hosts := []string{}
if strings.HasPrefix(u, "[/") {
// split domains and upstream string
domainsAndUpstream := strings.Split(strings.TrimPrefix(u, "[/"), "/]")
if len(domainsAndUpstream) != 2 {
return UpstreamConfig{}, fmt.Errorf("wrong upstream specification: %s", u)
}

// split domains list
for _, host := range strings.Split(domainsAndUpstream[0], "/") {
if host != "" {
if err := utils.IsValidHostname(host); err != nil {
return UpstreamConfig{}, err
}
hosts = append(hosts, strings.ToLower(host+"."))
} else {
// empty domain specification means `unqualified names only`
hosts = append(hosts, UnqualifiedNames)
}
}
u = domainsAndUpstream[1]
}

// # excludes more specific domain from reserved upstreams querying
if u == "#" && len(hosts) > 0 {
for _, host := range hosts {
domainReservedUpstreams[host] = nil
}
continue
}

// create an upstream
dnsUpstream, err := addressToUpstreamFunction(u, upstream.Options{Bootstrap: bootstrapDNS, Timeout: timeout})
if err != nil {
return UpstreamConfig{}, fmt.Errorf("cannot prepare the upstream %s (%s): %s", u, bootstrapDNS, err)
}

if len(hosts) > 0 {
for _, host := range hosts {
_, ok := domainReservedUpstreams[host]
if !ok {
domainReservedUpstreams[host] = []upstream.Upstream{}
}
domainReservedUpstreams[host] = append(domainReservedUpstreams[host], dnsUpstream)
}
log.Printf("Upstream %d: %s is reserved for next domains: %s", i, dnsUpstream.Address(), strings.Join(hosts, ", "))
} else {
log.Printf("Upstream %d: %s", i, dnsUpstream.Address())
upstreams = append(upstreams, dnsUpstream)
}
}
return UpstreamConfig{Upstreams: upstreams, DomainReservedUpstreams: domainReservedUpstreams}, nil
}

// validateConfig verifies that the supplied configuration is valid and returns an error if it's not
func (p *Proxy) validateConfig() error {
if p.started {
Expand All @@ -195,8 +100,12 @@ func (p *Proxy) validateConfig() error {
return errors.New("cannot create an HTTPS listener without TLS config")
}

if len(p.Upstreams) == 0 {
if len(p.DomainsReservedUpstreams) == 0 {
if p.UpstreamConfig == nil {
return errors.New("no default upstreams specified")
}

if len(p.UpstreamConfig.Upstreams) == 0 {
if len(p.UpstreamConfig.DomainReservedUpstreams) == 0 {
return errors.New("no upstreams specified")
}
return errors.New("no default upstreams specified")
Expand Down
4 changes: 2 additions & 2 deletions proxy/dns64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestProxyWithDNS64(t *testing.T) {

// Let's create test A request to ipv4OnlyHost and exchange it with test proxy
req := createHostTestMessage(ipv4OnlyHost)
resp, _, err := dnsProxy.exchange(req, dnsProxy.Upstreams)
resp, _, err := dnsProxy.exchange(req, dnsProxy.UpstreamConfig.Upstreams)
if err != nil {
t.Fatalf("Can not exchange test message for %s cause: %s", ipv4OnlyHost, err)
}
Expand Down Expand Up @@ -74,7 +74,7 @@ func TestProxyWithDNS64(t *testing.T) {
func TestDNS64Race(t *testing.T) {
dnsProxy := createTestProxy(t, nil)
dnsProxy.nat64Prefix = prefix
dnsProxy.Upstreams = append(dnsProxy.Upstreams, dnsProxy.Upstreams[0])
dnsProxy.UpstreamConfig.Upstreams = append(dnsProxy.UpstreamConfig.Upstreams, dnsProxy.UpstreamConfig.Upstreams[0])

// Start listening
err := dnsProxy.Start()
Expand Down
3 changes: 2 additions & 1 deletion proxy/lookup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ func TestLookupIPAddr(t *testing.T) {
if err != nil {
t.Fatalf("cannot prepare the upstream: %s", err)
}
p.Upstreams = append(upstreams, dnsUpstream)
p.UpstreamConfig = &UpstreamConfig{}
p.UpstreamConfig.Upstreams = append(upstreams, dnsUpstream)

// Init the proxy
p.Init()
Expand Down
51 changes: 13 additions & 38 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package proxy
import (
"net"
"net/http"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -96,9 +95,10 @@ type DNSContext struct {
StartTime time.Time // processing start time
Upstream upstream.Upstream // upstream that resolved DNS request

// Upstream servers to use for this request
// CustomUpstreamConfig -- custom upstream servers configuration
// to use for this request only.
// If set, Resolve() uses it instead of default servers
Upstreams []upstream.Upstream
CustomUpstreamConfig *UpstreamConfig

ecsReqIP net.IP // ECS IP used in request
ecsReqMask uint8 // ECS mask used in request
Expand Down Expand Up @@ -258,11 +258,17 @@ func (p *Proxy) Resolve(d *DNSContext) error {
return nil
}

host := d.Req.Question[0].Name
var upstreams []upstream.Upstream

// Get custom upstreams first -- note that they might be empty
upstreams := d.Upstreams
if len(upstreams) == 0 {
// get upstreams for the specified hostname
upstreams = p.getUpstreamsForDomain(d.Req.Question[0].Name)
if d.CustomUpstreamConfig != nil {
upstreams = d.CustomUpstreamConfig.getUpstreamsForDomain(host)
}

// If nothing found in the custom upstreams, start using the default ones
if upstreams == nil {
upstreams = p.UpstreamConfig.getUpstreamsForDomain(host)
}

// execute the DNS request
Expand Down Expand Up @@ -308,37 +314,6 @@ func (p *Proxy) Resolve(d *DNSContext) error {
return err
}

// getUpstreamsForDomain looks for a domain in reserved domains map and returns a list of corresponding upstreams.
// returns default upstreams list if domain isn't found. More specific domains take priority over less specific domains.
// For example, map contains the following keys: host.com and www.host.com
// If we are looking for domain mail.host.com, this method will return value of host.com key
// If we are looking for domain www.host.com, this method will return value of www.host.com key
// If more specific domain value is nil, it means that domain was excluded and should be exchanged with default upstreams
func (p *Proxy) getUpstreamsForDomain(host string) []upstream.Upstream {
if len(p.DomainsReservedUpstreams) == 0 {
return p.Upstreams
}

dotsCount := strings.Count(host, ".")
if dotsCount < 2 {
return p.DomainsReservedUpstreams[UnqualifiedNames]
}

for i := 1; i <= dotsCount; i++ {
h := strings.SplitAfterN(host, ".", i)
name := h[i-1]
if u, ok := p.DomainsReservedUpstreams[strings.ToLower(name)]; ok {
if u == nil {
// domain was excluded from reserved upstreams querying
return p.Upstreams
}
return u
}
}

return p.Upstreams
}

// Set EDNS Client-Subnet data in DNS request
func (p *Proxy) processECS(d *DNSContext) {
d.ecsReqIP = nil
Expand Down
4 changes: 2 additions & 2 deletions proxy/proxy_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
// Get response from general or subnet cache
// Return TRUE if response is found in cache
func (p *Proxy) replyFromCache(d *DNSContext) bool {
if p.cache == nil || len(d.Upstreams) > 0 {
if p.cache == nil || d.CustomUpstreamConfig != nil {
// Do not use cache if:
// it is disabled
// the query is with custom upstreams
Expand Down Expand Up @@ -46,7 +46,7 @@ func (p *Proxy) replyFromCache(d *DNSContext) bool {

// Store response in general or subnet cache
func (p *Proxy) setInCache(d *DNSContext, resp *dns.Msg) {
if p.cache == nil || len(d.Upstreams) > 0 {
if p.cache == nil || d.CustomUpstreamConfig != nil {
// Do not use cache if:
// it is disabled
// the query is with custom upstreams
Expand Down
Loading

0 comments on commit d8f4a77

Please sign in to comment.