/
net.go
190 lines (163 loc) · 4 KB
/
net.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
// Copyright (C) 2020-2021, IrineSistiana
//
// This file is part of mosdns.
//
// mosdns is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// mosdns is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package netlist
import (
"encoding/binary"
"errors"
"fmt"
"github.com/IrineSistiana/mosdns/v3/dispatcher/pkg/utils"
"net"
"strconv"
)
// maxUint64 is the largest uint64 value: all 64 bits set.
const maxUint64 uint64 = 1<<64 - 1

// ErrInvalidIP is returned when a string or net.IP cannot be turned into an IPv6 value.
var ErrInvalidIP = errors.New("invalid ip")
// IPv6 represents an IPv6 address as two big-endian 64-bit words:
// element 0 holds address bytes 0-7 (the high half) and element 1 holds
// bytes 8-15 (the low half). Conv stores IPv4 addresses in their
// IPv4-mapped form (::ffff:a.b.c.d).
type IPv6 [2]uint64
// mask is an IPv6 network prefix length, counted in bits from the most
// significant bit. IPv4 prefix lengths are stored with 96 added (see
// ParseCIDR). ParseCIDR keeps it within 0-128; NewNet does not validate it.
type mask uint8
// getMask returns one 64-bit word of the network mask for prefix length m.
// offset picks the word: 0 for the high 64 bits, 1 for the low 64 bits.
func getMask(m mask, offset uint8) uint64 {
	prefix, start := uint8(m), 64*offset
	if prefix <= start {
		// The prefix ends before this word begins: no bits are masked in.
		return 0
	}
	// Shifting all-ones right by the number of covered bits and inverting
	// leaves exactly that many leading ones. Go defines a shift by >= 64 on
	// a uint64 to produce 0, so a fully covered word becomes all ones.
	return ^(maxUint64 >> (prefix - start))
}
// Net represents an IP network as a base address plus a prefix length.
// When built with NewNet, ip has every bit outside the prefix cleared;
// Contains relies on that when it compares a masked address with ip.
type Net struct {
	ip   IPv6 // network address (host bits zeroed by NewNet)
	mask mask // prefix length in bits, IPv6 scale
}
// NewNet builds a Net from an address and a prefix length. The prefix
// length is on the IPv6 scale, so an IPv4 prefix needs 96 added first.
// Bits of ipv6 beyond the prefix are cleared.
//
// m is converted to the uint8-based mask type without range checks;
// callers are expected to pass a value in 0-128.
func NewNet(ipv6 IPv6, m int) Net {
	prefix := mask(m)
	// ipv6 is an array passed by value, so masking it in place is local.
	ipv6[0] &= getMask(prefix, 0)
	ipv6[1] &= getMask(prefix, 1)
	return Net{ip: ipv6, mask: prefix}
}
// Contains reports whether ip falls inside the network n.
// Each 64-bit word of ip is masked with the network's prefix and must
// equal the corresponding word of the network address.
func (n Net) Contains(ip IPv6) bool {
	return ip[0]&getMask(n.mask, 0) == n.ip[0] &&
		ip[1]&getMask(n.mask, 1) == n.ip[1]
}
// v4InV6Prefix is the low word of an IPv4-mapped IPv6 address with the IPv4
// part zeroed: bytes 10-11 of the address are 0xff, 0xff (::ffff:0:0).
var v4InV6Prefix = uint64(0xffff) << 32

// Conv converts ip to type IPv6.
// A 16-byte ip is copied as-is. A 4-byte ip is stored in its IPv4-mapped
// form (::ffff:a.b.c.d). Any other length yields ErrInvalidIP.
func Conv(ip net.IP) (IPv6, error) {
	var out IPv6
	switch len(ip) {
	case net.IPv6len:
		out[0] = binary.BigEndian.Uint64(ip[:8])
		out[1] = binary.BigEndian.Uint64(ip[8:])
	case net.IPv4len:
		out[1] = v4InV6Prefix + uint64(binary.BigEndian.Uint32(ip))
	default:
		return IPv6{}, ErrInvalidIP
	}
	return out, nil
}
// IPVersion tells whether a parsed address is IPv4 (including
// IPv4-mapped IPv6) or IPv6. Version4 is the zero value.
type IPVersion uint8

// Known IP versions.
const (
	Version4 IPVersion = 0
	Version6 IPVersion = 1
)
// ParseIP parses s as a textual IPv4 or IPv6 address.
// It returns the address as IPv6 together with its version. Version4 is
// reported whenever net.IP.To4 succeeds, which also covers IPv4-mapped
// IPv6 text such as "::ffff:1.2.3.4". An unparsable s yields ErrInvalidIP.
func ParseIP(s string) (IPv6, IPVersion, error) {
	parsed := net.ParseIP(s)
	if parsed == nil {
		return IPv6{}, 0, ErrInvalidIP
	}
	addr, err := Conv(parsed)
	if err != nil {
		return IPv6{}, 0, err
	}
	version := Version6
	if parsed.To4() != nil {
		version = Version4
	}
	return addr, version, nil
}
// ParseCIDR parses s as a CIDR notation IP address and prefix length.
// As defined in RFC 4632 and RFC 4291.
//
// Without a "/" the address is treated as a single host (/128 on the IPv6
// scale). IPv4 prefix lengths (0-32) are converted to the IPv6 scale by
// adding 96. A prefix length that is not a decimal number, or is too large
// for the address family, is rejected.
func ParseCIDR(s string) (Net, error) {
	ipStr, maskStr, ok := utils.SplitString2(s, "/")
	if !ok { // no "/": a single address
		ipv6, _, err := ParseIP(s)
		if err != nil {
			return Net{}, err
		}
		return NewNet(ipv6, 128), nil
	}

	// ip
	ipv6, version, err := ParseIP(ipStr)
	if err != nil {
		return Net{}, err
	}

	// mask
	maskLen, err := strconv.ParseUint(maskStr, 10, 0)
	if err != nil {
		return Net{}, fmt.Errorf("invalid cidr mask %s", s)
	}

	// Check the bound in the address's own family BEFORE adding the IPv4
	// offset: adding 96 to a huge value could wrap around uint64 and slip
	// past a "> 128" check (e.g. 2^64-1 + 96 == 95).
	maxLen := uint64(128)
	if version != Version6 {
		maxLen = 32
	}
	if maskLen > maxLen {
		return Net{}, fmt.Errorf("cidr mask %s overflow", s)
	}

	// IPv4 prefixes are stored on the IPv6 scale (::ffff:a.b.c.d/96+n).
	if version != Version6 {
		maskLen += 96
	}
	return NewNet(ipv6, int(maskLen)), nil
}
// ToNetIP converts ip into a freshly allocated 16-byte net.IP.
func (ip IPv6) ToNetIP() net.IP {
	out := make(net.IP, net.IPv6len)
	uint64ToBytes(ip, out)
	return out
}
// toNetMask expands the prefix length m into a 16-byte net.IPMask.
func (m mask) toNetMask() net.IPMask {
	words := [2]uint64{getMask(m, 0), getMask(m, 1)}
	out := make(net.IPMask, net.IPv6len)
	uint64ToBytes(words, out)
	return out
}
// uint64ToBytes writes the two words of in into out as big-endian bytes:
// in[0] fills out[0:8] and in[1] fills out[8:16].
// out must have length of at least 16, otherwise it panics.
func uint64ToBytes(in [2]uint64, out []byte) {
	hi, lo := in[0], in[1]
	binary.BigEndian.PutUint64(out[:8], hi)
	binary.BigEndian.PutUint64(out[8:], lo)
}
// ToNetIPNet converts n into a *net.IPNet with a 16-byte IP and a
// 16-byte mask.
func (n Net) ToNetIPNet() *net.IPNet {
	return &net.IPNet{
		IP:   n.ip.ToNetIP(),
		Mask: n.mask.toNetMask(),
	}
}
// String formats n in CIDR notation using net.IPNet's String method.
func (n Net) String() string {
	ipNet := n.ToNetIPNet()
	return ipNet.String()
}