forked from sbilly/go-windivert2
/
ipfilter.go
72 lines (56 loc) · 1009 Bytes
/
ipfilter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
package utils
import (
"net"
"sync"
"github.com/imgk/shadow/utils/iptree"
)
// IPFilter is a set of IP addresses and CIDR networks that can be queried
// for membership with Lookup.
//
// The embedded RWMutex guards Tree. Methods without the "Unsafe" prefix
// (Reset, Add, Lookup) acquire the lock themselves. UnsafeReset and
// UnsafeAdd perform no locking, so the caller must hold the write lock
// (f.Lock) around them, for example to batch several inserts under one lock.
type IPFilter struct {
	// Embedding promotes Lock/Unlock/RLock/RUnlock onto *IPFilter, which is
	// what callers of the Unsafe* methods use to take the lock themselves.
	sync.RWMutex
	// Tree stores the inserted addresses and networks. Every stored value is
	// nil; only presence in the tree is used by Lookup.
	Tree *iptree.Tree
}
// NewIPFilter returns an empty, ready-to-use IPFilter backed by a freshly
// allocated iptree.Tree. The embedded mutex starts in its zero (unlocked)
// state, so no explicit initialization is needed for it.
func NewIPFilter() *IPFilter {
	return &IPFilter{Tree: iptree.NewTree()}
}
// Reset discards every entry by replacing Tree with a new, empty tree.
//
// It holds the write lock for the whole operation. The unlock is deferred so
// that a panic inside UnsafeReset, if recovered by a caller, cannot leave the
// filter permanently locked.
func (f *IPFilter) Reset() {
	f.Lock()
	defer f.Unlock()
	f.UnsafeReset()
}
// UnsafeReset is the lock-free variant of Reset: it points f at a brand-new,
// empty tree. The caller must hold the write lock. The previous tree is not
// modified; f simply stops referencing it.
func (f *IPFilter) UnsafeReset() {
	empty := iptree.NewTree()
	f.Tree = empty
}
// Add inserts s into the filter while holding the write lock.
//
// s may be a single IP address (e.g. "192.0.2.1", "2001:db8::1") or a network
// in CIDR notation (e.g. "192.0.2.0/24"). If s is neither, the parse error
// from net.ParseCIDR is returned and the filter is unchanged.
//
// The unlock is deferred so that a panic during insertion cannot leave the
// filter write-locked for every later caller.
func (f *IPFilter) Add(s string) error {
	f.Lock()
	defer f.Unlock()
	return f.UnsafeAdd(s)
}
// UnsafeAdd is the lock-free variant of Add; the caller must hold the write
// lock.
//
// s is first tried as a plain IP address. Otherwise it is parsed as CIDR
// notation, and the resulting network is inserted. net.ParseCIDR masks off
// the host bits, so "192.0.2.7/24" inserts 192.0.2.0/24. A string that is
// neither form yields net.ParseCIDR's error.
func (f *IPFilter) UnsafeAdd(s string) error {
	var (
		network *net.IPNet
		err     error
	)
	if addr := net.ParseIP(s); addr != nil {
		return f.addIP(addr)
	}
	if _, network, err = net.ParseCIDR(s); err != nil {
		return err
	}
	return f.addCIDR(network)
}
// addIP inserts a single address into the tree with a nil value, since only
// membership is tracked. It never fails today; the error result keeps its
// signature parallel to addCIDR so UnsafeAdd can return either one directly.
func (f *IPFilter) addIP(addr net.IP) error {
	tree := f.Tree
	tree.InplaceInsertIP(addr, nil)
	return nil
}
// addCIDR inserts a whole network into the tree with a nil value, since only
// membership is tracked. It never fails today; the error result keeps its
// signature parallel to addIP so UnsafeAdd can return either one directly.
func (f *IPFilter) addCIDR(network *net.IPNet) error {
	tree := f.Tree
	tree.InplaceInsertNet(network, nil)
	return nil
}
// Lookup reports whether ip matches an entry in the filter, as decided by
// iptree.Tree.GetByIP. This is presumably either an exact address or a
// containing network; confirm against the iptree package.
//
// It holds only the read lock, so several Lookups may run concurrently. The
// read unlock is deferred so that a panic in GetByIP cannot leave a reader
// registered on the mutex, which would block every subsequent Add or Reset.
func (f *IPFilter) Lookup(ip net.IP) bool {
	f.RLock()
	defer f.RUnlock()
	_, found := f.Tree.GetByIP(ip)
	return found
}