Skip to content

Commit

Permalink
Merge pull request #93 in DNS/adguard-dns from fix/414 to master
Browse files Browse the repository at this point in the history
* commit '914eb612cd0da015b98c151b7ac603fb4126a2c3':
  Add bootstrap DNS to readme
  Fix review comments
  Close test upstream
  Added bootstrap DNS to the config file DNS healthcheck now uses the upstream package methods
  goimports files
  Added CoreDNS plugin setup and replaced forward
  Added factory method for creating DNS upstreams
  Added health-check method
  Added persistent connections cache
  Upstream plugin prototype
  • Loading branch information
ameshkov committed Nov 6, 2018
2 parents 2449075 + 914eb61 commit 4a357f1
Show file tree
Hide file tree
Showing 13 changed files with 955 additions and 114 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possib
* `parental_enabled` — Parental control-based DNS requests filtering
* `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17
* `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes)
* `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname
* `upstream_dns` — List of upstream DNS servers
* `filters` — List of filters, each filter has the following values:
* `ID` - filter ID (must be unique)
* `url` — URL pointing to the filter contents (filtering rules)
* `enabled` — Current filter's status (enabled/disabled)
* `user_rules` — User-specified filtering rules
Expand Down
4 changes: 3 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type coreDNSConfig struct {
Pprof string `yaml:"-"`
Cache string `yaml:"-"`
Prometheus string `yaml:"-"`
BootstrapDNS string `yaml:"bootstrap_dns"`
UpstreamDNS []string `yaml:"upstream_dns"`
}

Expand Down Expand Up @@ -100,6 +101,7 @@ var config = configuration{
SafeBrowsingEnabled: false,
BlockedResponseTTL: 10, // in seconds
QueryLogEnabled: true,
BootstrapDNS: "8.8.8.8:53",
UpstreamDNS: defaultDNS,
Cache: "cache",
Prometheus: "prometheus :9153",
Expand Down Expand Up @@ -253,7 +255,7 @@ const coreDNSConfigTemplate = `.:{{.Port}} {
hosts {
fallthrough
}
{{if .UpstreamDNS}}forward . {{range .UpstreamDNS}}{{.}} {{end}}{{end}}
{{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}}
{{.Cache}}
{{.Prometheus}}
}
Expand Down
147 changes: 34 additions & 113 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"path/filepath"
Expand All @@ -15,8 +14,9 @@ import (
"strings"
"time"

"github.com/AdguardTeam/AdGuardHome/upstream"

corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
"github.com/miekg/dns"
"gopkg.in/asaskevich/govalidator.v4"
)

Expand Down Expand Up @@ -81,6 +81,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
"protection_enabled": config.CoreDNS.ProtectionEnabled,
"querylog_enabled": config.CoreDNS.QueryLogEnabled,
"running": isRunning(),
"bootstrap_dns": config.CoreDNS.BootstrapDNS,
"upstream_dns": config.CoreDNS.UpstreamDNS,
"version": VersionString,
}
Expand Down Expand Up @@ -134,17 +135,14 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
errortext := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusBadRequest)
errorText := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
// if empty body -- user is asking for default servers
hosts, err := sanitiseDNSServers(string(body))
if err != nil {
httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err)
return
}
hosts := strings.Fields(string(body))

if len(hosts) == 0 {
config.CoreDNS.UpstreamDNS = defaultDNS
} else {
Expand All @@ -153,34 +151,34 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {

err = writeAllConfigs()
if err != nil {
errortext := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
tellCoreDNSToReload()
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}

func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
errortext := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errortext)
http.Error(w, errortext, 400)
errorText := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errorText)
http.Error(w, errorText, 400)
return
}
hosts := strings.Fields(string(body))

if len(hosts) == 0 {
errortext := fmt.Sprintf("No servers specified")
log.Println(errortext)
http.Error(w, errortext, http.StatusBadRequest)
errorText := fmt.Sprintf("No servers specified")
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}

Expand All @@ -198,118 +196,41 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {

jsonVal, err := json.Marshal(result)
if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}

w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}

func checkDNS(input string) error {
input, err := sanitizeDNSServer(input)

u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS)

if err != nil {
return err
}
defer u.Close()

req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}

prefix, host := splitDNSServerPrefixServer(input)
alive, err := upstream.IsAlive(u)

c := dns.Client{
Timeout: time.Minute,
}
switch prefix {
case "tls://":
c.Net = "tcp-tls"
}

resp, rtt, err := c.Exchange(&req, host)
if err != nil {
return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err)
}
trace("exchange with %s took %v", input, rtt)
if len(resp.Answer) != 1 {
return fmt.Errorf("DNS server %s returned wrong answer", input)
}
if t, ok := resp.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
}
}

return nil
}

func sanitiseDNSServers(input string) ([]string, error) {
fields := strings.Fields(input)
hosts := make([]string, 0)
for _, field := range fields {
sanitized, err := sanitizeDNSServer(field)
if err != nil {
return hosts, err
}
hosts = append(hosts, sanitized)
if !alive {
return fmt.Errorf("DNS server has not passed the healthcheck: %s", input)
}
return hosts, nil
}

func getDNSServerPrefix(input string) string {
prefix := ""
switch {
case strings.HasPrefix(input, "dns://"):
prefix = "dns://"
case strings.HasPrefix(input, "tls://"):
prefix = "tls://"
}
return prefix
}

func splitDNSServerPrefixServer(input string) (string, string) {
prefix := getDNSServerPrefix(input)
host := strings.TrimPrefix(input, prefix)
return prefix, host
}

func sanitizeDNSServer(input string) (string, error) {
prefix, host := splitDNSServerPrefixServer(input)
host = appendPortIfMissing(prefix, host)
{
h, _, err := net.SplitHostPort(host)
if err != nil {
return "", err
}
ip := net.ParseIP(h)
if ip == nil {
return "", fmt.Errorf("invalid DNS server field: %s", h)
}
}
return prefix + host, nil
}

func appendPortIfMissing(prefix, input string) string {
port := "53"
switch prefix {
case "tls://":
port = "853"
}
_, _, err := net.SplitHostPort(input)
if err == nil {
return input
}
return net.JoinHostPort(input, port)
return nil
}

//noinspection GoUnusedParameter
Expand Down
2 changes: 2 additions & 0 deletions coredns.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"sync" // Include all plugins.

_ "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
_ "github.com/AdguardTeam/AdGuardHome/upstream"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/coremain"
_ "github.com/coredns/coredns/plugin/auto"
Expand Down Expand Up @@ -79,6 +80,7 @@ var directives = []string{
"loop",
"forward",
"proxy",
"upstream",
"erratic",
"whoami",
"on",
Expand Down
1 change: 1 addition & 0 deletions openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ paths:
protection_enabled: true
querylog_enabled: true
running: true
bootstrap_dns: 8.8.8.8:53
upstream_dns:
- 1.1.1.1
- 1.0.0.1
Expand Down
109 changes: 109 additions & 0 deletions upstream/dns_upstream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package upstream

import (
"crypto/tls"
"time"

"github.com/miekg/dns"
"golang.org/x/net/context"
)

// DnsUpstream is a very simple upstream implementation for plain DNS
type DnsUpstream struct {
endpoint string // IP:port
timeout time.Duration // Max read and write timeout
proto string // Protocol (tcp, tcp-tls, or udp)
transport *Transport // Persistent connections cache
}

// NewDnsUpstream creates a new DNS upstream
func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) {

u := &DnsUpstream{
endpoint: endpoint,
timeout: defaultTimeout,
proto: proto,
}

var tlsConfig *tls.Config

if proto == "tcp-tls" {
tlsConfig = new(tls.Config)
tlsConfig.ServerName = tlsServerName
}

// Initialize the connections cache
u.transport = NewTransport(endpoint)
u.transport.tlsConfig = tlsConfig
u.transport.Start()

return u, nil
}

// Exchange provides an implementation for the Upstream interface
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {

resp, err := u.exchange(u.proto, query)

// Retry over TCP if response is truncated
if err == dns.ErrTruncated && u.proto == "udp" {
resp, err = u.exchange("tcp", query)
} else if err == dns.ErrTruncated && resp != nil {
// Reassemble something to be sent to client
m := new(dns.Msg)
m.SetReply(query)
m.Truncated = true
m.Authoritative = true
m.Rcode = dns.RcodeSuccess
return m, nil
}

if err != nil {
resp = &dns.Msg{}
resp.SetRcode(resp, dns.RcodeServerFailure)
}

return resp, err
}

// Clear resources
func (u *DnsUpstream) Close() error {

// Close active connections
u.transport.Stop()
return nil
}

// Performs a synchronous query. It sends the message m via the conn
// c and waits for a reply. The conn c is not closed.
func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) {

// Establish a connection if needed (or reuse cached)
conn, err := u.transport.Dial(proto)
if err != nil {
return nil, err
}

// Write the request with a timeout
conn.SetWriteDeadline(time.Now().Add(u.timeout))
if err = conn.WriteMsg(query); err != nil {
conn.Close() // Not giving it back
return nil, err
}

// Write response with a timeout
conn.SetReadDeadline(time.Now().Add(u.timeout))
r, err = conn.ReadMsg()
if err != nil {
conn.Close() // Not giving it back
} else if err == nil && r.Id != query.Id {
err = dns.ErrId
conn.Close() // Not giving it back
}

if err == nil {
// Return it back to the connections cache if there were no errors
u.transport.Yield(conn)
}
return r, err
}
Loading

0 comments on commit 4a357f1

Please sign in to comment.