-
Notifications
You must be signed in to change notification settings - Fork 20
/
trie.go
106 lines (85 loc) · 2.04 KB
/
trie.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package trie
import (
"context"
"errors"
"log/slog"
"net/netip"
"github.com/Asutorufa/yuhaiin/pkg/log"
"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
"github.com/Asutorufa/yuhaiin/pkg/net/trie/cidr"
"github.com/Asutorufa/yuhaiin/pkg/net/trie/domain"
"github.com/Asutorufa/yuhaiin/pkg/utils/yerror"
)
// Trie maps rule strings to marks of type T.
//
// Strings that parse as an IP address or CIDR prefix live in a CIDR trie;
// every other string is treated as a domain and lives in a domain trie.
// Both fields are pointers, so the zero value has nil tries; build one with
// NewTrie (or call Clear) before use.
type Trie[T any] struct {
	cidr   *cidr.Cidr[T]   // IP addresses and CIDR prefixes
	domain *domain.Fqdn[T] // everything else, handled as domain names
}
// Insert associates mark with the rule str.
//
// str is classified in this order:
//   - a CIDR prefix such as "10.0.0.0/8" goes into the CIDR trie as-is;
//   - a bare IP address goes into the CIDR trie with a full-length mask
//     (32 bits for IPv4, 128 bits for IPv6, including IPv4-mapped IPv6);
//   - anything else goes into the domain trie.
//
// An empty str is ignored.
func (x *Trie[T]) Insert(str string, mark T) {
	if str == "" {
		return
	}
	if prefix, err := netip.ParsePrefix(str); err == nil {
		x.cidr.InsertCIDR(prefix, mark)
		return
	}
	if ip, err := netip.ParseAddr(str); err == nil {
		// For a valid address BitLen is 32 (IPv4) or 128 (IPv6), i.e. a
		// mask that covers exactly this one address.
		x.cidr.InsertIP(ip, ip.BitLen(), mark)
		return
	}
	x.domain.Insert(str, mark)
}
var (
	// ErrSkipResolver is the error produced by SkipResolver. Search treats a
	// resolution error matching it (via errors.Is) as a silent miss and does
	// not log it.
	ErrSkipResolver = errors.New("skip resolve domain")

	// SkipResolver is a netapi.ErrorResolver whose callback ignores the domain
	// it is given and always returns ErrSkipResolver. Presumably it is plugged
	// in where name resolution should be suppressed — confirm against callers.
	SkipResolver = netapi.ErrorResolver(func(string) error { return ErrSkipResolver })
)
// Search reports the mark stored for addr and whether any rule matched.
//
// IP-typed addresses are looked up only in the CIDR trie. For any other
// address the domain trie is consulted first. On a miss, the address is
// resolved with addr.IP and the resulting IP is looked up in the CIDR trie.
// A resolution error matching ErrSkipResolver is treated as a silent miss.
// Any other resolution error is logged at warn level and also treated as a
// miss.
func (x *Trie[T]) Search(ctx context.Context, addr netapi.Address) (mark T, ok bool) {
	if addr.Type() == netapi.IP {
		// NOTE(review): yerror.Must presumably panics if addr.IP returns an
		// error; for an IP-typed address that is assumed never to happen.
		return x.cidr.SearchIP(yerror.Must(addr.IP(ctx)))
	}

	// Domain rules are consulted before IP rules for non-IP addresses.
	mark, ok = x.domain.Search(addr)
	if ok {
		return mark, ok
	}

	ip, err := addr.IP(ctx)
	if err != nil {
		if !errors.Is(err, ErrSkipResolver) {
			log.Warn("dns lookup failed, skip match ip", slog.Any("addr", addr), slog.Any("err", err))
		}
		// Fall back to whatever the domain lookup reported (ok is false here).
		return mark, ok
	}
	return x.cidr.SearchIP(ip)
}
// Remove deletes the rule str. It uses the same classification as Insert:
//   - a CIDR prefix is removed from the CIDR trie;
//   - a bare IP address is removed from the CIDR trie with a full-length mask
//     (32 bits for IPv4, 128 bits for IPv6, including IPv4-mapped IPv6);
//   - anything else is removed from the domain trie.
//
// An empty str is ignored.
func (x *Trie[T]) Remove(str string) {
	if str == "" {
		return
	}
	if prefix, err := netip.ParsePrefix(str); err == nil {
		x.cidr.RemoveCIDR(prefix)
		return
	}
	if ip, err := netip.ParseAddr(str); err == nil {
		// For a valid address BitLen is 32 (IPv4) or 128 (IPv6), matching the
		// mask Insert uses for a bare IP.
		x.cidr.RemoveIP(ip, ip.BitLen())
		return
	}
	x.domain.Remove(str)
}
// SearchWithDefault is like Search but returns only the mark. It returns
// defaultT when Search reports that no rule matched addr.
func (x *Trie[T]) SearchWithDefault(ctx context.Context, addr netapi.Address, defaultT T) T {
	if mark, ok := x.Search(ctx, addr); ok {
		return mark
	}
	return defaultT
}
// Clear drops every rule by replacing both the CIDR and the domain trie with
// newly constructed ones (the same constructors NewTrie uses). No locking is
// done here. The error result is always nil.
func (x *Trie[T]) Clear() error {
	x.cidr, x.domain = cidr.NewCidrMapper[T](), domain.NewDomainMapper[T]()
	return nil
}
// NewTrie returns a Trie whose CIDR and domain tries are both freshly
// constructed, ready for Insert, Remove and Search.
func NewTrie[T any]() *Trie[T] {
	t := new(Trie[T])
	t.cidr = cidr.NewCidrMapper[T]()
	t.domain = domain.NewDomainMapper[T]()
	return t
}