-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
rewrites.go
224 lines (185 loc) · 5.42 KB
/
rewrites.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
package filtering
import (
"fmt"
"net/netip"
"slices"
"strings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// Legacy DNS rewrites

// LegacyRewrite is a single legacy DNS rewrite record.  Domain and Answer are
// the user-facing fields stored in the YAML configuration; IP and Type are
// derived from them by normalize and are never serialized.
//
// Instances of *LegacyRewrite must never be nil.
type LegacyRewrite struct {
	// Domain is the domain pattern for which this rewrite should work.  It is
	// either an exact domain name or a wildcard such as "*.example.com", which
	// matches subdomains of example.com but not example.com itself.  normalize
	// lowercases it.
	Domain string `yaml:"domain"`

	// Answer is the IP address, canonical name, or one of the special
	// values: "A" or "AAAA".
	Answer string `yaml:"answer"`

	// IP is the IP address that should be used in the response if Type is
	// dns.TypeA or dns.TypeAAAA.  normalize sets it to the zero netip.Addr for
	// the special "A" and "AAAA" answers, which mean that the upstream
	// response for that type should be used instead.
	IP netip.Addr `yaml:"-"`

	// Type is the DNS record type: A, AAAA, or CNAME.
	Type uint16 `yaml:"-"`
}
// equal reports whether rw and other describe the same rewrite rule, that is,
// whether they have identical Domain and Answer fields.  The derived IP and
// Type fields are not compared.
func (rw *LegacyRewrite) equal(other *LegacyRewrite) (ok bool) {
	if rw.Domain != other.Domain {
		return false
	}

	return rw.Answer == other.Answer
}
// matchesQType reports whether the entry applies to a question of type qt.
// CNAME entries apply to every question type.  Other entries only apply to A
// and AAAA questions, and only when their Type equals qt or they carry no IP
// address.
func (rw *LegacyRewrite) matchesQType(qt uint16) (ok bool) {
	switch {
	case rw.Type == dns.TypeCNAME:
		// A canonical name is relevant whatever record type was requested.
		return true
	case qt != dns.TypeA && qt != dns.TypeAAAA:
		// Address entries say nothing about non-address question types.
		return false
	case rw.Type == qt:
		return true
	default:
		// After normalize, the special "A"/"AAAA" answers are the non-CNAME
		// entries without an IP, and they are included for the other address
		// type as well.
		return rw.IP == (netip.Addr{})
	}
}
// normalize makes sure that the new or decoded entry is normalized with regards
// to domain name case and the fields derived from Answer.  It lowercases Domain
// and then sets IP and Type from Answer:
//
//   - the special values "A" and "AAAA" clear IP and set Type to dns.TypeA or
//     dns.TypeAAAA respectively;
//   - an Answer that parses as an IP address sets IP to it and Type to
//     dns.TypeA for IPv4 addresses or dns.TypeAAAA otherwise (including
//     IPv4-mapped IPv6 addresses, for which netip.Addr.Is4 is false);
//   - any other Answer is treated as a canonical name: IP is cleared and Type
//     is set to dns.TypeCNAME.
//
// If rw is nil, it returns an error.
func (rw *LegacyRewrite) normalize() (err error) {
	if rw == nil {
		return errors.Error("nil rewrite entry")
	}

	// TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix and
	// use it in matchDomainWildcard instead of using strings.ToLower
	// everywhere.
	rw.Domain = strings.ToLower(rw.Domain)

	switch rw.Answer {
	case "AAAA":
		rw.IP = netip.Addr{}
		rw.Type = dns.TypeAAAA

		return nil
	case "A":
		rw.IP = netip.Addr{}
		rw.Type = dns.TypeA

		return nil
	default:
		// Go on.
	}

	ip, err := netip.ParseAddr(rw.Answer)
	if err != nil {
		// Not an IP address, so the answer is a canonical name.  This is an
		// expected case, hence only a debug message.  Clear IP like the
		// special-value branches do, so that an entry whose Answer changed from
		// an address to a name doesn't keep a stale address.
		log.Debug("normalizing legacy rewrite: %s", err)
		rw.IP = netip.Addr{}
		rw.Type = dns.TypeCNAME

		return nil
	}

	rw.IP = ip
	if ip.Is4() {
		rw.Type = dns.TypeA
	} else {
		rw.Type = dns.TypeAAAA
	}

	return nil
}
// isWildcard reports whether pat is a wildcard domain pattern, i.e. whether it
// starts with "*.".
func isWildcard(pat string) bool {
	return strings.HasPrefix(pat, "*.")
}
// matchDomainWildcard reports whether host is covered by the wildcard pattern.
// A non-wildcard pattern never matches.
func matchDomainWildcard(host, wildcard string) (ok bool) {
	if !isWildcard(wildcard) {
		return false
	}

	// Strip only the leading "*" and keep the dot, so that "*.example.com"
	// matches "sub.example.com" but neither "example.com" nor "xexample.com".
	suffix := wildcard[1:]

	return strings.HasSuffix(host, suffix)
}
// Compare orders rewrites for slices.SortFunc so that the entries which take
// precedence come first.  It returns a negative number when rw must sort before
// b, a positive number when after, and zero when their order doesn't matter:
//
//  1. CNAME rewrites sort before A and AAAA ones;
//  2. otherwise, exact domains sort before wildcard ones;
//  3. otherwise, longer (more specific) domains sort before shorter ones.
func (rw *LegacyRewrite) Compare(b *LegacyRewrite) (res int) {
	rwIsCNAME, bIsCNAME := rw.Type == dns.TypeCNAME, b.Type == dns.TypeCNAME
	switch {
	case rwIsCNAME && !bIsCNAME:
		return -1
	case !rwIsCNAME && bIsCNAME:
		return 1
	}

	rwIsWild, bIsWild := isWildcard(rw.Domain), isWildcard(b.Domain)
	switch {
	case rwIsWild == bIsWild:
		// Same kind of pattern: the longer domain is the more specific one.
		return len(b.Domain) - len(rw.Domain)
	case rwIsWild:
		return 1
	default:
		return -1
	}
}
// prepareRewrites normalizes every legacy DNS rewrite in d.conf.Rewrites in
// place.  It stops at the first entry that fails to normalize and returns that
// error annotated with the entry's index.
func (d *DNSFilter) prepareRewrites() (err error) {
	for i := range d.conf.Rewrites {
		if err = d.conf.Rewrites[i].normalize(); err != nil {
			return fmt.Errorf("at index %d: %w", i, err)
		}
	}

	return nil
}
// findRewrites returns the list of matched rewrite entries. If rewrites are
// empty, but matched is true, the domain is found among the rewrite rules but
// not for this question type.
//
// An entry matches host when its Domain equals host or is a wildcard covering
// host.  Matching entries that also apply to qtype are sorted with Compare
// (CNAME before A/AAAA, exact before wildcard, longer domain before shorter).
// The result is then cut just before the first wildcard entry, always keeping
// at least one entry.  The exact entries ahead of the first wildcard are
// therefore all returned without any wildcard.  If the very first entry is a
// wildcard, only that single entry is returned.
//
// NOTE(review): because CNAME entries sort before A/AAAA ones, a wildcard CNAME
// wins over an exact A/AAAA entry for the same host.  Confirm this is intended.
func findRewrites(
	entries []*LegacyRewrite,
	host string,
	qtype uint16,
) (rewrites []*LegacyRewrite, matched bool) {
	for _, ent := range entries {
		if ent.Domain == host || matchDomainWildcard(host, ent.Domain) {
			matched = true
			if ent.matchesQType(qtype) {
				rewrites = append(rewrites, ent)
			}
		}
	}

	if len(rewrites) == 0 {
		return nil, matched
	}

	slices.SortFunc(rewrites, (*LegacyRewrite).Compare)

	isWild := func(r *LegacyRewrite) (ok bool) { return isWildcard(r.Domain) }
	if firstWild := slices.IndexFunc(rewrites, isWild); firstWild >= 0 {
		// A plain rewrites[:firstWild] would be empty when the best entry is
		// itself a wildcard, so keep at least one item.
		rewrites = rewrites[:max(1, firstWild)]
	}

	return rewrites, matched
}
// setRewriteResult fills res from rewrites for A and AAAA questions.  Every
// rewrite whose Type equals qtype contributes its IP to res.IPList.  The first
// such rewrite without an IP (an "A"/"AAAA" exception) instead sets res.Reason
// to NotFilteredNotFound, so that the answer is taken from upstream, and stops
// the processing.  IPs appended before that point are kept.  Questions of other
// types leave res untouched.  res must not be nil.
func setRewriteResult(res *Result, host string, rewrites []*LegacyRewrite, qtype uint16) {
	for _, entry := range rewrites {
		if entry.Type != qtype || (qtype != dns.TypeA && qtype != dns.TypeAAAA) {
			continue
		}

		if entry.IP == (netip.Addr{}) {
			// "A"/"AAAA" exception: allow getting from upstream.
			res.Reason = NotFilteredNotFound

			return
		}

		res.IPList = append(res.IPList, entry.IP)
		log.Debug("rewrite: a/aaaa for %s is %s", host, entry.IP)
	}
}
// cloneRewrites returns a deep copy of entries: a new slice of the same length
// whose elements point to fresh copies of the original rewrites.  Copying the
// struct value is a deep copy because all of LegacyRewrite's fields are value
// types.  Revisit this if reference-typed fields are ever added.
func cloneRewrites(entries []*LegacyRewrite) (clone []*LegacyRewrite) {
	clone = make([]*LegacyRewrite, 0, len(entries))
	for _, orig := range entries {
		dup := *orig
		clone = append(clone, &dup)
	}

	return clone
}