This repository has been archived by the owner on Jun 12, 2022. It is now read-only.
/
nat.go
400 lines (339 loc) · 10.5 KB
/
nat.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
package service
import (
"crypto/cipher"
"errors"
"math/rand"
"net"
"sync"
"time"
onet "github.com/Shadowsocks-NET/outline-ss-server/net"
"github.com/Shadowsocks-NET/outline-ss-server/service/metrics"
ss "github.com/Shadowsocks-NET/outline-ss-server/shadowsocks"
"github.com/Shadowsocks-NET/outline-ss-server/socks"
"go.uber.org/zap"
wgreplay "golang.zx2c4.com/wireguard/replay"
)
// natconn is one entry of the client-address NAT table. It holds the state
// needed to relay target replies back to a single client.
type natconn struct {
	// session is the single session bound to this natconn when a legacy
	// Shadowsocks method is in use. Shadowsocks 2022 methods keep their
	// sessions in the natmap session table instead.
	session *session
	// swg counts the sessions still attached to this natconn; once it drops
	// to zero the natconn may be removed from the table.
	swg sync.WaitGroup

	// cipher references the access key's cipher; keyID names that key.
	cipher *ss.Cipher
	keyID  string
	// clientLocation is cached here so it is not recomputed for every
	// downstream packet of the association.
	clientLocation string

	// oobCache is the out-of-band data passed to WriteMsgUDP when replying,
	// describing the interface index and IP on which client requests arrived,
	// so replies leave from the same interface and IP.
	oobCache []byte
	// lastSeenConn is the UDP conn on which the most recent client packet arrived.
	lastSeenConn onet.UDPPacketConn
}
// WriteToUDP forwards buf to dst through the legacy session bound to c.
func (c *natconn) WriteToUDP(buf []byte, dst *net.UDPAddr) (int, error) {
	ses := c.session
	return ses.WriteToUDP(buf, dst)
}
// ReadFromUDP reads one target packet into buf through the legacy session
// bound to c, returning its length and source address.
func (c *natconn) ReadFromUDP(buf []byte) (n int, addr *net.UDPAddr, err error) {
	n, addr, err = c.session.ReadFromUDP(buf)
	return n, addr, err
}
// timedCopy relays packets from the target side back to the client until a
// target read times out or fails with net.ErrClosed.
//
// For Shadowsocks 2022 methods, pass the session being served as ses; legacy
// methods pass nil and are served through c.session. Each downstream packet
// is assembled and encrypted in place inside one reusable buffer, then sent
// through c.lastSeenConn (with c.oobCache) to the client's last seen address.
// Per-packet results are reported to sm; other failures are logged and the
// loop continues with the next packet.
func (c *natconn) timedCopy(ses *session, sm metrics.ShadowsocksMetrics) {
	// Bytes reserved in front of the header when packing a Shadowsocks 2022
	// packet that has no separate header.
	// NOTE(review): presumably the 24-byte XChaCha20-Poly1305 nonce of the
	// ChaCha20 2022 method — confirm against ss.Pack.
	const spec2022PrefixLen = 24

	// pkt is used for in-place encryption of downstream UDP packets, with the layout
	// [unused?][salt/nonce?][header][body][tag][extra]
	// Room is reserved for the longest possible header (IPv6 address, maximum
	// padding), so shorter headers leave some leading bytes unused.
	pkt := make([]byte, UDPPacketBufferSize)
	saltSize := c.cipher.SaltSize()
	cipherConfig := c.cipher.Config()

	// Leave enough room at the beginning of the packet for a max-length header (i.e. IPv6).
	var bodyStart int
	switch {
	case cipherConfig.UDPHasSeparateHeader:
		bodyStart = ss.UDPServerMessageHeaderFixedLength + ss.MaxPaddingLength + socks.SocksAddressIPv6Length
	case cipherConfig.IsSpec2022:
		bodyStart = spec2022PrefixLen + ss.UDPServerMessageHeaderFixedLength + ss.MaxPaddingLength + socks.SocksAddressIPv6Length
	default:
		bodyStart = saltSize + socks.SocksAddressIPv6Length
	}

	expired := false
	for {
		var bodyLen, proxyClientBytes int
		var lastSeenAddr *net.UDPAddr
		connError := func() (connError *onet.ConnectionError) {
			var (
				raddr       *net.UDPAddr
				err         error
				headerStart int
				pktStart    int
				buf         []byte
			)

			// `readBuf` receives the plaintext body in `pkt`:
			// [unused?][salt/nonce?][header][body][tag][unused]
			// |-------------- bodyStart -------------|[ readBuf ]
			readBuf := pkt[bodyStart:]

			// The client address is sampled only after the read returns. The
			// read can block for the whole NAT timeout, during which the client
			// may move. The reply must go to the current address, which matches
			// the lastSeenConn/oobCache used for the write below.
			switch {
			case cipherConfig.IsSpec2022:
				bodyLen, raddr, err = ses.ReadFromUDP(readBuf)
				lastSeenAddr = ses.lastSeenAddr
			default:
				bodyLen, raddr, err = c.ReadFromUDP(readBuf)
				lastSeenAddr = c.session.lastSeenAddr
			}
			if err != nil {
				if netErr, ok := err.(net.Error); ok {
					if netErr.Timeout() {
						expired = true
						return nil
					}
				}
				if errors.Is(err, net.ErrClosed) { //FIXME: locate the bug and remove this mitigation.
					expired = true
				}
				return onet.NewConnectionError("ERR_READ", "Failed to read from target", err)
			}

			socksAddrLen := socks.SocksAddressIPv6Length
			if raddr.IP.To4() != nil {
				socksAddrLen = socks.SocksAddressIPv4Length
			}

			// Write the header so that it ends exactly where the body begins.
			switch {
			case cipherConfig.IsSpec2022:
				// Random padding is added only to replies from port 53 (DNS).
				var paddingLen int
				if raddr.Port == 53 {
					paddingLen = rand.Intn(ss.MaxPaddingLength + 1)
				}
				headerStart = bodyStart - ss.UDPServerMessageHeaderFixedLength - paddingLen - socksAddrLen
				ss.WriteUDPHeader(pkt[headerStart:], ss.HeaderTypeServerPacket, ses.ssid, ses.spid, ses.csid, raddr, nil, paddingLen)
				// Each downstream packet of this session gets the next packet ID.
				ses.spid++
			default:
				headerStart = bodyStart - socksAddrLen
				socks.WriteUDPAddrAsSocksAddr(pkt[headerStart:], raddr)
			}

			// `plaintextBuf` concatenates the header and body:
			// [unused?][salt/nonce?][header][body][tag][unused]
			// |--- headerStart ----|[ plaintextBuf ]
			plaintextBuf := pkt[headerStart : bodyStart+bodyLen]

			switch {
			case cipherConfig.UDPHasSeparateHeader:
				pktStart = headerStart
			case cipherConfig.IsSpec2022:
				pktStart = headerStart - spec2022PrefixLen
			default:
				pktStart = headerStart - saltSize
			}

			// `packBuf` adds space for the salt/nonce and tag.
			// `buf` shows the space that was used.
			// [unused?][salt/nonce?][header][body][tag][unused]
			//          [            packBuf                   ]
			//          [            buf               ]
			packBuf := pkt[pktStart:]
			switch {
			case cipherConfig.UDPHasSeparateHeader:
				buf, err = ss.PackAesWithSeparateHeader(packBuf, plaintextBuf, c.cipher, ses.saead)
			default:
				buf, err = ss.Pack(packBuf, plaintextBuf, c.cipher)
			}
			if err != nil {
				return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err)
			}

			proxyClientBytes, _, err = c.lastSeenConn.WriteMsgUDP(buf, c.oobCache, lastSeenAddr)
			if err != nil {
				return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err)
			}
			return nil
		}()

		status := "OK"
		if connError != nil {
			logger.Warn(connError.Message,
				zap.Stringer("listenAddress", c.lastSeenConn.LocalAddr()),
				zap.Stringer("clientAddress", lastSeenAddr),
				zap.Error(connError.Cause),
			)
			status = connError.Status
		}
		if expired {
			break
		}
		sm.AddUDPPacketFromTarget(c.clientLocation, c.keyID, status, bodyLen, proxyClientBytes)
	}
}
// natmap is the packet NAT table. It maps client address strings to natconns
// and, for Shadowsocks 2022 methods, client session IDs to sessions. The
// embedded RWMutex guards both maps.
type natmap struct {
	sync.RWMutex

	// keyConn is keyed by the client address's String() form.
	keyConn map[string]*natconn
	// sidConn is keyed by client session ID.
	sidConn map[uint64]*session

	timeout time.Duration
	metrics metrics.ShadowsocksMetrics
	// running counts the natconn goroutines started by StartNatconn.
	running *sync.WaitGroup
}
// newNATmap returns an empty NAT table that stores timeout, reports to sm,
// and registers the goroutines it starts with running.
func newNATmap(timeout time.Duration, sm metrics.ShadowsocksMetrics, running *sync.WaitGroup) *natmap {
	m := &natmap{timeout: timeout, metrics: sm, running: running}
	m.keyConn = make(map[string]*natconn)
	m.sidConn = make(map[uint64]*session)
	return m
}
// GetByClientAddress returns the natconn registered under the client address
// string key, or nil if there is none.
func (m *natmap) GetByClientAddress(key string) *natconn {
	m.RLock()
	entry := m.keyConn[key]
	m.RUnlock()
	return entry
}
// GetByClientSessionID returns the session registered under client session
// ID csid, or nil if there is none.
func (m *natmap) GetByClientSessionID(csid uint64) *session {
	m.RLock()
	ses := m.sidConn[csid]
	m.RUnlock()
	return ses
}
// AddNatEntry creates a natconn for clientAddr, stores it in the table under
// clientAddr.String() (replacing any previous entry), and returns it. The
// relay goroutine is not started here; see StartNatconn.
func (m *natmap) AddNatEntry(clientAddr *net.UDPAddr, clientConn onet.UDPPacketConn, cipher *ss.Cipher, clientLocation, keyID string, ses *session) *natconn {
	key := clientAddr.String()
	entry := &natconn{
		session:        ses,
		cipher:         cipher,
		keyID:          keyID,
		clientLocation: clientLocation,
		lastSeenConn:   clientConn,
	}

	m.Lock()
	m.keyConn[key] = entry
	m.Unlock()
	return entry
}
// StartNatconn records a new NAT entry in the metrics and starts the
// goroutine that owns entry's lifetime. For legacy methods that goroutine
// relays target replies itself and then closes the target conn. For
// Shadowsocks 2022 methods it waits until every session attached to entry
// has finished. Afterwards it removes entry from the table, updates the
// metrics, and marks itself done on m.running.
//
// NOTE(review): the 2022 path assumes AddSession was called before this,
// because with a zero swg counter Wait returns at once and the entry is
// removed immediately. An AddSession racing with a Wait on a zero counter
// would also break the sync.WaitGroup Add/Wait ordering rule.
func (m *natmap) StartNatconn(clientAddr *net.UDPAddr, entry *natconn, cipherConfig ss.CipherConfig) {
	m.metrics.AddUDPNatEntry()
	m.running.Add(1)

	go func() {
		if cipherConfig.IsSpec2022 {
			// Sessions do the relaying; just wait for all of them to end.
			entry.swg.Wait()
		} else {
			entry.timedCopy(nil, m.metrics)
			entry.session.targetConn.Close()
		}

		key := clientAddr.String()
		m.Lock()
		delete(m.keyConn, key)
		m.Unlock()

		m.metrics.RemoveUDPNatEntry()
		m.running.Done()
	}()
}
// AddSession registers ses under client session ID csid, attaches it to
// entry, and starts a goroutine that relays the session's target replies.
// When the relay ends, the goroutine closes the target conn, removes the
// session from the table, and detaches it from entry.
func (m *natmap) AddSession(csid uint64, ses *session, entry *natconn) {
	m.Lock()
	m.sidConn[csid] = ses
	m.Unlock()

	forget := func() {
		m.Lock()
		delete(m.sidConn, csid)
		m.Unlock()
	}

	entry.swg.Add(1)
	go func() {
		entry.timedCopy(ses, m.metrics)
		ses.targetConn.Close()
		forget()
		entry.swg.Done()
	}()
}
// Close sets the read deadline of every target conn in the table to now.
// Pending target reads then time out, which ends their timedCopy loops. The
// conns are not closed here. If setting a deadline fails, the error from the
// last failing conn encountered is returned.
func (m *natmap) Close() error {
	m.Lock()
	defer m.Unlock()

	deadline := time.Now()
	var lastErr error

	for _, entry := range m.keyConn {
		if entry.session == nil {
			continue
		}
		if err := entry.session.targetConn.SetReadDeadline(deadline); err != nil {
			lastErr = err
		}
	}
	for _, ses := range m.sidConn {
		if err := ses.targetConn.SetReadDeadline(deadline); err != nil {
			lastErr = err
		}
	}
	return lastErr
}
// session holds the per-association relay state for one target conn.
// Legacy methods bind one session to each natconn. Shadowsocks 2022 methods
// may attach several sessions to a natconn, looked up by client session ID.
type session struct {
	// Shadowsocks 2022 Edition only: the client session ID, the server
	// session ID, and the server packet ID for the next downstream packet.
	csid []byte
	ssid []byte
	spid uint64

	// targetConn is the conn used to talk to targets.
	targetConn onet.UDPPacketConn
	// defaultTimeout is the read timeout applied after non-DNS writes.
	defaultTimeout time.Duration
	// readDeadline is the deadline most recently set on targetConn by
	// onWrite. It is zero until the first write, and onWrite never moves it
	// earlier.
	readDeadline time.Time
	// fastClose, while still unused, lets a session whose only write was a
	// DNS query stop as soon as a DNS response arrives.
	fastClose sync.Once

	// lastSeenAddr is the address of the most recent client packet; its
	// String() form is the NAT table key.
	lastSeenAddr *net.UDPAddr
	// lastSeenTime is the Unix time, in seconds, of the most recent client packet.
	lastSeenTime int64
	// cfilter is the client's sliding-window replay filter; it rejects
	// duplicate or out-of-window incoming packets.
	cfilter *wgreplay.Filter

	// caead and saead are used only by 2022-blake3-aes-256-gcm. They are
	// initialized with the client and server session subkeys respectively.
	caead cipher.AEAD
	saead cipher.AEAD
}
// newSession returns a session with the given IDs, default timeout and
// AEADs. Its last-seen time is set to now and it starts with an empty
// replay filter. targetConn and lastSeenAddr are left for the caller to set.
func newSession(csid, ssid []byte, defaultTimeout time.Duration, caead, saead cipher.AEAD) *session {
	s := &session{csid: csid, ssid: ssid, caead: caead, saead: saead}
	s.defaultTimeout = defaultTimeout
	s.lastSeenTime = time.Now().Unix()
	s.cfilter = new(wgreplay.Filter)
	return s
}
// isDNS reports whether addr uses port 53, the standard DNS port.
func isDNS(addr *net.UDPAddr) bool {
	const dnsPort = 53
	return dnsPort == addr.Port
}
// onWrite runs before each packet is sent to addr. Unless this is the
// session's first write and it is a DNS query, it disables fast close. It
// then pushes the target read deadline forward, never moving it earlier,
// using a 17-second timeout for DNS and defaultTimeout otherwise.
func (s *session) onWrite(addr *net.UDPAddr) {
	dnsQuery := isDNS(addr)
	// readDeadline stays zero until the first write has set it.
	firstWrite := s.readDeadline.IsZero()

	// Fast close is only allowed if there has been exactly one write, and it
	// was a DNS query. Consuming the Once with a no-op disables it for good.
	if !(dnsQuery && firstWrite) {
		s.fastClose.Do(func() {})
	}

	var timeout time.Duration
	if dnsQuery {
		// Shorten timeout as required by RFC 5452 Section 10.
		timeout = 17 * time.Second
	} else {
		timeout = s.defaultTimeout
	}

	if deadline := time.Now().Add(timeout); deadline.After(s.readDeadline) {
		s.readDeadline = deadline
		// Best effort: the error from SetReadDeadline is not checked.
		s.targetConn.SetReadDeadline(deadline)
	}
}
// onRead runs after each successful read from the target. Fast close is
// one-shot: whichever of onWrite's disable or this call reaches fastClose
// first consumes it. If this call gets there first and the packet came from
// a DNS port, the read deadline is moved to now so the next read times out.
func (s *session) onRead(addr *net.UDPAddr) {
	expireIfDNS := func() {
		if !isDNS(addr) {
			return
		}
		// The next ReadFrom() should time out immediately.
		s.targetConn.SetReadDeadline(time.Now())
	}
	s.fastClose.Do(expireIfDNS)
}
// WriteToUDP sends buf to dst over the target conn after onWrite has updated
// the read deadline and fast-close state.
func (s *session) WriteToUDP(buf []byte, dst *net.UDPAddr) (n int, err error) {
	s.onWrite(dst)
	n, err = s.targetConn.WriteToUDP(buf, dst)
	return n, err
}
// ReadFromUDP reads one packet from the target conn into buf. On success it
// calls onRead with the source address, which may trigger fast close.
func (s *session) ReadFromUDP(buf []byte) (int, *net.UDPAddr, error) {
	n, addr, err := s.targetConn.ReadFromUDP(buf)
	if err != nil {
		return n, addr, err
	}
	s.onRead(addr)
	return n, addr, nil
}