diff --git a/dnsrewrite_test.go b/dnsrewrite_test.go index c42c2f5..df4e0e7 100644 --- a/dnsrewrite_test.go +++ b/dnsrewrite_test.go @@ -140,7 +140,12 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) { |refused_host^$dnsrewrite=REFUSED |new_cname^$dnsrewrite=othercname +|new_mx^$dnsrewrite=NOERROR;MX;32 new_mx_host |new_txt^$dnsrewrite=NOERROR;TXT;new_txtcontent +|1.2.3.4.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new_ptr + +|https_record^$dnsrewrite=NOERROR;HTTPS;32 https_record_host alpn=h3 +|svcb_record^$dnsrewrite=NOERROR;SVCB;32 svcb_record_host alpn=h3 |https_type^$dnstype=HTTPS,dnsrewrite=REFUSED @@ -318,6 +323,24 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) { } }) + t.Run("new_mx", func(t *testing.T) { + res, ok := dnsEngine.Match(path.Base(t.Name())) + assert.False(t, ok) + + dnsr := res.DNSRewritesAll() + if assert.Len(t, dnsr, 1) { + nr := dnsr[0] + assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode) + assert.Equal(t, dns.TypeMX, nr.DNSRewrite.RRType) + + mx := &rules.DNSMX{ + Exchange: "new_mx_host", + Preference: 32, + } + assert.Equal(t, mx, nr.DNSRewrite.Value) + } + }) + t.Run("new_txt", func(t *testing.T) { res, ok := dnsEngine.Match(path.Base(t.Name())) assert.False(t, ok) @@ -331,6 +354,63 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) { } }) + t.Run("1.2.3.4.in-addr.arpa", func(t *testing.T) { + res, ok := dnsEngine.Match(path.Base(t.Name())) + assert.False(t, ok) + + dnsr := res.DNSRewritesAll() + if assert.Len(t, dnsr, 1) { + nr := dnsr[0] + assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode) + assert.Equal(t, dns.TypePTR, nr.DNSRewrite.RRType) + assert.Equal(t, "new_ptr", nr.DNSRewrite.Value) + } + }) + + t.Run("https_record", func(t *testing.T) { + res, ok := dnsEngine.Match(path.Base(t.Name())) + assert.False(t, ok) + + dnsr := res.DNSRewritesAll() + if assert.Len(t, dnsr, 1) { + nr := dnsr[0] + assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode) + assert.Equal(t, dns.TypeHTTPS, nr.DNSRewrite.RRType) + + p := map[string]string{ + "alpn": "h3", + } + svcb := &rules.DNSSVCB{ + Params: p, + Target: "https_record_host", + Priority: 32, + } + assert.Equal(t, svcb, nr.DNSRewrite.Value) + } + }) + + t.Run("svcb_record", func(t *testing.T) { + res, ok := dnsEngine.Match(path.Base(t.Name())) + assert.False(t, ok) + + dnsr := res.DNSRewritesAll() + if assert.Len(t, dnsr, 1) { + nr := dnsr[0] + assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode) + assert.Equal(t, dns.TypeSVCB, nr.DNSRewrite.RRType) + + p := map[string]string{ + "alpn": "h3", + } + svcb := &rules.DNSSVCB{ + Params: p, + Target: "svcb_record_host", + Priority: 32, + } + assert.Equal(t, svcb, nr.DNSRewrite.Value) + } + }) + t.Run("https_type", func(t *testing.T) { r := DNSRequest{ Hostname: path.Base(t.Name()), diff --git a/rules/dnsrewrite.go b/rules/dnsrewrite.go index 6d4ce45..2bd13ee 100644 --- a/rules/dnsrewrite.go +++ b/rules/dnsrewrite.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net" + "strconv" "strings" "github.com/miekg/dns" @@ -16,25 +17,35 @@ type RCode = int // type. type RRType = uint16 -// RRValue is the value of a resource record. If the coresponding RR is either -// dns.TypeA or dns.TypeAAAA, the underlying type of RRValue is net.IP. If the -// RR is dns.TypeTXT, the underlying type of Value is string. Otherwise, -// currently, it is nil. New types may be added in the future. +// RRValue is the value of a resource record. +// +// If the coresponding RRType is either dns.TypeA or dns.TypeAAAA, the +// underlying type of RRValue is net.IP. +// +// If the RRType is dns.TypeMX, the underlying value is a non-nil *DNSMX. +// +// If the RRType is either dns.TypePTR or dns.TypeTXT, the underlying type of +// Value is string. +// +// If the RRType is either dns.TypeHTTPS or dns.TypeSVCB, the underlying value +// is a non-nil *DNSSVCB. +// +// Otherwise, currently, it is nil. New types may be added in the future. type RRValue = interface{} // DNSRewrite is a DNS rewrite ($dnsrewrite) rule. type DNSRewrite struct { - // RCode is the new DNS RCODE. - RCode RCode - // RRType is the new DNS resource record (RR) type. It is only non-zero - // if RCode is dns.RCodeSuccess. - RRType RRType // Value is the value for the record. See the RRValue documentation for // more details. Value RRValue // NewCNAME is the new CNAME. If set, clients must ignore other fields, // resolve the CNAME, and set the new A and AAAA records accordingly. NewCNAME string + // RCode is the new DNS RCODE. + RCode RCode + // RRType is the new DNS resource record (RR) type. It is only non-zero + // if RCode is dns.RCodeSuccess. + RRType RRType } // loadDNSRewrite loads the $dnsrewrite modifier. @@ -107,27 +118,107 @@ func loadDNSRewriteShort(s string) (rewrite *DNSRewrite, err error) { }, nil } -// loadDNSRewritesNormal loads the normal version for of the $dnsrewrite -// modifier. -func loadDNSRewriteNormal(rcodeStr, rrStr, valStr string) (rewrite *DNSRewrite, err error) { - rcode, ok := dns.StringToRcode[strings.ToUpper(rcodeStr)] - if !ok { - return nil, fmt.Errorf("unknown rcode: %q", rcodeStr) +// DNSMX is the type of RRValue values returned for MX records in DNS rewrites. +type DNSMX struct { + Exchange string + Preference uint16 +} + +// DNSSVCB is the type of RRValue values returned for HTTPS and SVCB records in +// dns rewrites. +// +// See https://tools.ietf.org/html/draft-ietf-dnsop-svcb-https-01. +type DNSSVCB struct { + Params map[string]string + Target string + Priority uint16 +} + +// dnsRewriteRRHandler is a function that parses values for specific resource +// record types. +type dnsRewriteRRHandler func(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) + +// strDNSRewriteRRHandler is a simple DNS rewrite handler that returns +// a *DNSRewrite with Value st to valStr. +func strDNSRewriteRRHandler(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) { + return &DNSRewrite{ + RCode: rcode, + RRType: rr, + Value: valStr, + }, nil +} + +// svcbDNSRewriteRRHandler is a DNS rewrite handler that parses SVCB and HTTPS +// rewrites. +// +// See https://tools.ietf.org/html/draft-ietf-dnsop-svcb-https-01. +// +// TODO(a.garipov): Currently, we only support the contiguous type of +// char-string values from the RFC. +func svcbDNSRewriteRRHandler(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) { + var name string + switch rr { + case dns.TypeHTTPS: + name = "https" + case dns.TypeSVCB: + name = "svcb" + default: + return nil, fmt.Errorf("unsupported svcb-like rr type: %d", rr) } - if rcode != dns.RcodeSuccess { + fields := strings.Split(valStr, " ") + if len(fields) < 2 { + return nil, fmt.Errorf("invalid %s %q: need at least two fields", name, valStr) + } + + var prio64 uint64 + prio64, err = strconv.ParseUint(fields[0], 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid %s priority: %w", name, err) + } + + if len(fields) == 2 { + v := &DNSSVCB{ + Priority: uint16(prio64), + Target: fields[1], + } + return &DNSRewrite{ - RCode: rcode, + RCode: rcode, + RRType: rr, + Value: v, }, nil } - rr, err := strToRR(rrStr) - if err != nil { - return nil, err + params := make(map[string]string, len(fields)-2) + for i, pair := range fields[2:] { + kv := strings.Split(pair, "=") + if l := len(kv); l != 2 { + return nil, fmt.Errorf("invalid %s param at index %d: got %d fields", name, i, l) + } + + // TODO(a.garipov): Validate for uniqueness? Validate against + // the currently specified list of params from the RFC? + params[kv[0]] = kv[1] } - switch rr { - case dns.TypeA: + v := &DNSSVCB{ + Priority: uint16(prio64), + Target: fields[1], + Params: params, + } + + return &DNSRewrite{ + RCode: rcode, + RRType: rr, + Value: v, + }, nil +} + +// dnsRewriteRRHandlers are the supported resource record types' rewrite +// handlers. +var dnsRewriteRRHandlers = map[RRType]dnsRewriteRRHandler{ + dns.TypeA: func(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) { ip := net.ParseIP(valStr) if ip4 := ip.To4(); ip4 == nil { return nil, fmt.Errorf("invalid ipv4: %q", valStr) @@ -138,7 +229,9 @@ func loadDNSRewriteNormal(rcodeStr, rrStr, valStr string) (rewrite *DNSRewrite, RRType: rr, Value: ip, }, nil - case dns.TypeAAAA: + }, + + dns.TypeAAAA: func(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) { ip := net.ParseIP(valStr) if ip == nil { return nil, fmt.Errorf("invalid ipv6: %q", valStr) @@ -151,20 +244,71 @@ func loadDNSRewriteNormal(rcodeStr, rrStr, valStr string) (rewrite *DNSRewrite, RRType: rr, Value: ip, }, nil - case dns.TypeCNAME: + }, + + dns.TypeCNAME: func(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) { return &DNSRewrite{ NewCNAME: valStr, }, nil - case dns.TypeTXT: + }, + + dns.TypeMX: func(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) { + parts := strings.SplitN(valStr, " ", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid mx: %q", valStr) + } + + var pref64 uint64 + pref64, err = strconv.ParseUint(parts[0], 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid mx preference: %w", err) + } + + v := &DNSMX{ + Exchange: parts[1], + Preference: uint16(pref64), + } + return &DNSRewrite{ RCode: rcode, RRType: rr, - Value: valStr, + Value: v, }, nil - default: + }, + + dns.TypePTR: strDNSRewriteRRHandler, + dns.TypeTXT: strDNSRewriteRRHandler, + + dns.TypeHTTPS: svcbDNSRewriteRRHandler, + dns.TypeSVCB: svcbDNSRewriteRRHandler, +} + +// loadDNSRewritesNormal loads the normal version for of the $dnsrewrite +// modifier. +func loadDNSRewriteNormal(rcodeStr, rrStr, valStr string) (rewrite *DNSRewrite, err error) { + rcode, ok := dns.StringToRcode[strings.ToUpper(rcodeStr)] + if !ok { + return nil, fmt.Errorf("unknown rcode: %q", rcodeStr) + } + + if rcode != dns.RcodeSuccess { + return &DNSRewrite{ + RCode: rcode, + }, nil + } + + rr, err := strToRRType(rrStr) + if err != nil { + return nil, err + } + + handler, ok := dnsRewriteRRHandlers[rr] + if !ok { return &DNSRewrite{ RCode: rcode, RRType: rr, }, nil } + + return handler(rcode, rr, valStr) } diff --git a/rules/dnsrewrite_test.go b/rules/dnsrewrite_test.go index 823719a..ae07ccd 100644 --- a/rules/dnsrewrite_test.go +++ b/rules/dnsrewrite_test.go @@ -50,7 +50,15 @@ func TestNetworkRule_Match_dnsRewrite(t *testing.T) { assert.Nil(t, err) assert.True(t, r.Match(req)) - r, err = NewNetworkRule("||example.org^$dnsrewrite=noerror;mx;hello", -1) + r, err = NewNetworkRule("||example.org^$dnsrewrite=noerror;mx;30 example.net", -1) + assert.Nil(t, err) + assert.True(t, r.Match(req)) + + r, err = NewNetworkRule("||example.org^$dnsrewrite=noerror;svcb;30 example.net alpn=h3", -1) + assert.Nil(t, err) + assert.True(t, r.Match(req)) + + r, err = NewNetworkRule("||example.org^$dnsrewrite=noerror;https;30 example.net", -1) assert.Nil(t, err) assert.True(t, r.Match(req)) @@ -59,6 +67,14 @@ func TestNetworkRule_Match_dnsRewrite(t *testing.T) { assert.True(t, r.Match(req)) }) + t.Run("success_reverse", func(t *testing.T) { + req := NewRequestForHostname("1.2.3.4.in-addr.arpa") + + r, err := NewNetworkRule("||1.2.3.4.in-addr.arpa^$dnsrewrite=noerror;ptr;example.net", -1) + assert.Nil(t, err) + assert.True(t, r.Match(req)) + }) + t.Run("parse_errors", func(t *testing.T) { _, err := NewNetworkRule("||example.org^$dnsrewrite=BADKEYWORD", -1) assert.NotNil(t, err) @@ -80,5 +96,20 @@ func TestNetworkRule_Match_dnsRewrite(t *testing.T) { _, err = NewNetworkRule("||example.org^$dnsrewrite=noerror;aaaa;127.0.0.1", -1) assert.NotNil(t, err) + + _, err = NewNetworkRule("||example.org^$dnsrewrite=noerror;mx;bad stuff", -1) + assert.NotNil(t, err) + + _, err = NewNetworkRule("||example.org^$dnsrewrite=noerror;mx;very bad stuff", -1) + assert.NotNil(t, err) + + _, err = NewNetworkRule("||example.org^$dnsrewrite=noerror;https;bad stuff", -1) + assert.NotNil(t, err) + + _, err = NewNetworkRule("||example.org^$dnsrewrite=noerror;svcb;bad stuff", -1) + assert.NotNil(t, err) + + _, err = NewNetworkRule("||example.org^$dnsrewrite=noerror;svcb;42 bad stuffs", -1) + assert.NotNil(t, err) }) } diff --git a/rules/rule.go b/rules/rule.go index 4cfab30..3134a04 100644 --- a/rules/rule.go +++ b/rules/rule.go @@ -115,9 +115,9 @@ func loadDomains(domains, sep string) (permittedDomains, restrictedDomains []str return } -// strToRR converts s to a DNS resource record (RR) type. s may be in any -// letter case. -func strToRR(s string) (rr RRType, err error) { +// strToRRType converts s to a DNS resource record (RR) type. s may be +// in any letter case. +func strToRRType(s string) (rr RRType, err error) { // TypeNone and TypeReserved are special cases in package dns. if strings.EqualFold(s, "none") || strings.EqualFold(s, "reserved") { return 0, errors.New("dns rr type is none or reserved") @@ -148,7 +148,7 @@ func loadDNSTypes(types string) (permittedTypes, restrictedTypes []RRType, err e rrStr = rrStr[1:] } - rr, err := strToRR(rrStr) + rr, err := strToRRType(rrStr) if err != nil { return nil, nil, fmt.Errorf("type %d (%q): %w", i, rrStr, err) } diff --git a/staticcheck.conf b/staticcheck.conf new file mode 100644 index 0000000..8da8c9b --- /dev/null +++ b/staticcheck.conf @@ -0,0 +1,9 @@ +checks = ["all"] +initialisms = [ + # See https://github.com/dominikh/go-tools/blob/master/config/config.go. + "inherit" +, "MX" +, "SVCB" +] +dot_import_whitelist = [] +http_status_code_whitelist = []