Skip to content

Commit

Permalink
rules: improve dns type matching
Browse files Browse the repository at this point in the history
  • Loading branch information
ainar-g committed Nov 24, 2020
1 parent 1032d62 commit 2d6df3d
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 2 deletions.
34 changes: 34 additions & 0 deletions dns_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,14 @@ func TestBadfilterRules(t *testing.T) {
func TestDNSEngine_MatchRequest_dnsType(t *testing.T) {
const rulesText = `
||simple^$dnstype=AAAA
||simple_case^$dnstype=aaaa
||reverse^$dnstype=~AAAA
||multiple^$dnstype=A|AAAA
||multiple_reverse^$dnstype=~A|~AAAA
||multiple_different^$dnstype=~A|AAAA
||simple_client^$client=127.0.0.1,dnstype=AAAA
||priority^$client=127.0.0.1
||priority^$client=127.0.0.1,dnstype=AAAA
`

ruleStorage := newTestRuleStorage(t, 1, rulesText)
Expand All @@ -411,6 +414,16 @@ func TestDNSEngine_MatchRequest_dnsType(t *testing.T) {
assert.False(t, ok)
})

t.Run("simple_case", func(t *testing.T) {
r := DNSRequest{Hostname: "simple_case", DNSType: dns.TypeAAAA}
_, ok := dnsEngine.MatchRequest(r)
assert.True(t, ok)

r.DNSType = dns.TypeA
_, ok = dnsEngine.MatchRequest(r)
assert.False(t, ok)
})

t.Run("reverse", func(t *testing.T) {
r := DNSRequest{Hostname: "reverse", DNSType: dns.TypeAAAA}
_, ok := dnsEngine.MatchRequest(r)
Expand Down Expand Up @@ -498,4 +511,25 @@ func TestDNSEngine_MatchRequest_dnsType(t *testing.T) {
_, ok = dnsEngine.MatchRequest(r)
assert.False(t, ok)
})

t.Run("priority", func(t *testing.T) {
r := DNSRequest{
Hostname: "priority",
DNSType: dns.TypeAAAA,
ClientIP: "127.0.0.1",
}

rules, ok := dnsEngine.MatchRequest(r)
assert.True(t, ok)
assert.Contains(t, rules.NetworkRule.Text(), "dnstype=")

r = DNSRequest{
Hostname: "priority",
DNSType: dns.TypeA,
ClientIP: "127.0.0.1",
}
rules, ok = dnsEngine.MatchRequest(r)
assert.True(t, ok)
assert.NotContains(t, rules.NetworkRule.Text(), "dnstype=")
})
}
6 changes: 6 additions & 0 deletions rules/network_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,9 @@ func (f *NetworkRule) IsHigherPriority(r *NetworkRule) bool {
if len(f.permittedDomains) != 0 || len(f.restrictedDomains) != 0 {
count++
}
if len(f.permittedDNSTypes) != 0 || len(f.restrictedDNSTypes) != 0 {
count++
}
if len(f.permittedClientTags) != 0 || len(f.restrictedClientTags) != 0 {
count++
}
Expand All @@ -345,6 +348,9 @@ func (f *NetworkRule) IsHigherPriority(r *NetworkRule) bool {
if len(r.permittedDomains) != 0 || len(r.restrictedDomains) != 0 {
rCount++
}
if len(r.permittedDNSTypes) != 0 || len(r.restrictedDNSTypes) != 0 {
rCount++
}
if len(r.permittedClientTags) != 0 || len(r.restrictedClientTags) != 0 {
rCount++
}
Expand Down
4 changes: 4 additions & 0 deletions rules/network_rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ func TestNetworkRulePriority(t *testing.T) {
// more modifiers -> less modifiers
compareRulesPriority(t, "||example.org$script,stylesheet", "||example.org$script", true)
compareRulesPriority(t, "||example.org$ctag=123,client=123", "||example.org$script", true)
compareRulesPriority(t, "||example.org$ctag=123,client=123,dnstype=AAAA", "||example.org$client=123,dnstype=AAAA", true)
}

func TestMatchSource(t *testing.T) {
Expand Down Expand Up @@ -773,6 +774,9 @@ func TestNetworkRule_Match_dnsType(t *testing.T) {
_, err = NewNetworkRule("||example.org^$dnstype=TXT|", -1)
assert.NotNil(t, err)

_, err = NewNetworkRule("||example.org^$dnstype=NONE", -1)
assert.NotNil(t, err)

_, err = NewNetworkRule("||example.org^$dnstype=INVALIDTYPE", -1)
assert.NotNil(t, err)
})
Expand Down
9 changes: 7 additions & 2 deletions rules/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,14 @@ func loadDNSTypes(types string) (permittedTypes []uint16, restrictedTypes []uint
t = t[1:]
}

rtype, ok := dns.StringToType[t]
// TypeNone and TypeReserved are special cases in package dns.
if strings.EqualFold(t, "none") || strings.EqualFold(t, "reserved") {
return nil, nil, fmt.Errorf("dns record type %d (%q) is none or reserved", i, t)
}

rtype, ok := dns.StringToType[strings.ToUpper(t)]
if !ok {
return nil, nil, fmt.Errorf("dns record type %d is invalid", i)
return nil, nil, fmt.Errorf("dns record type %d (%q) is invalid", i, t)
}

if restricted {
Expand Down

0 comments on commit 2d6df3d

Please sign in to comment.