forked from shadowsocks/shadowsocks-go
-
Notifications
You must be signed in to change notification settings - Fork 1
/
tcpConn.go
239 lines (214 loc) · 6.08 KB
/
tcpConn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
package shadowsocks
import (
"io"
"net"
"sync"
"time"
"go.uber.org/zap"
"github.com/arthurkiller/shadowsocks-go/encrypt"
)
var (
	// BufferSize is the length, in bytes, of every buffer handed out by
	// readBufferPool (and the base length of writeBufferPool buffers).
	// 0x1FFFF is 131071 bytes — just under 128 KiB, not 32K.
	// Both pools read BufferSize each time they allocate, so changing it
	// after buffers have been pooled leaves buffers of mixed sizes in them.
	BufferSize = 0x1FFFF
	// writeBuffOffset is extra headroom added to writeBufferPool buffers,
	// presumably so cipher output that is slightly longer than its input
	// still fits. 0x7F is 127 bytes.
	// NOTE(review): an older comment said "128" — confirm which was meant.
	writeBuffOffset = 0x7F

	// readBufferPool yields []byte slices of length BufferSize
	// (SecureConn.readBuf, which receives raw ciphertext).
	readBufferPool = sync.Pool{
		New: func() interface{} {
			return make([]byte, BufferSize)
		},
	}
	// writeBufferPool yields []byte slices of length
	// BufferSize+writeBuffOffset. Both SecureConn.writeBuf (Encrypt output)
	// and SecureConn.dataCache (Decrypt output) are drawn from it.
	writeBufferPool = sync.Pool{
		New: func() interface{} {
			return make([]byte, BufferSize+writeBuffOffset)
		},
	}
)
// SecureConn layers the shadowsocks protocol over a plain connection:
// bytes written through it are encrypted with the embedded encrypt.Cipher
// and bytes read from it are decrypted with the same cipher. NewSecureConn
// returns a *SecureConn as a net.Conn.
type SecureConn struct {
	net.Conn       // transport that carries the ciphertext
	encrypt.Cipher // encrypts outgoing data, decrypts incoming data

	// Pooled scratch space: readBuf receives ciphertext read from Conn,
	// writeBuf receives ciphertext produced by Write, and dataCache holds
	// decrypted bytes that did not yet fit into a caller's slice.
	readBuf, writeBuf, dataCache []byte

	// datalen counts the pending bytes at the front of dataCache; timeout
	// is the deadline in seconds given to NewSecureConn (<= 0: none set).
	datalen, timeout int
}
// NewSecureConn wraps c in a *SecureConn that encrypts and decrypts with
// cipher, taking its scratch buffers from the package-level pools.
//
// A positive timeout sets one absolute deadline, timeout seconds from now,
// on the connection. Nothing in this file refreshes it afterwards. A
// timeout <= 0 leaves the connection's deadlines untouched.
func NewSecureConn(c net.Conn, cipher encrypt.Cipher, timeout int) net.Conn {
	sc := &SecureConn{Conn: c, Cipher: cipher, timeout: timeout}
	sc.readBuf = readBufferPool.Get().([]byte)
	sc.writeBuf = writeBufferPool.Get().([]byte)
	sc.dataCache = writeBufferPool.Get().([]byte)

	if timeout <= 0 {
		return sc
	}
	deadline := time.Now().Add(time.Duration(timeout) * time.Second)
	sc.SetDeadline(deadline) // error ignored: the constructor has no error return
	return sc
}
// Close closes the underlying connection and returns the pooled buffers.
//
// Each buffer field is reset to nil once it has been returned. Calling Close
// more than once, or after CloseRead/CloseWrite, therefore never Puts the
// same slice into a pool twice. A double Put would let two connections
// later share one buffer. The SecureConn must not be read from or written
// to after Close.
func (c *SecureConn) Close() error {
	if c.readBuf != nil {
		readBufferPool.Put(c.readBuf)
	}
	if c.writeBuf != nil {
		writeBufferPool.Put(c.writeBuf)
	}
	if c.dataCache != nil {
		writeBufferPool.Put(c.dataCache)
	}
	// writeBuf must be cleared too; forgetting it caused double Puts.
	c.readBuf, c.writeBuf, c.dataCache = nil, nil, nil
	return c.Conn.Close()
}
// halfCloseUnsupportedError is returned by CloseRead/CloseWrite when the
// wrapped connection has no matching half-close method. Its value is the
// name of the unsupported operation.
type halfCloseUnsupportedError string

func (e halfCloseUnsupportedError) Error() string {
	return "shadowsocks: underlying connection does not support " + string(e)
}

// CloseRead shuts down the reading half of the underlying connection and
// returns the read-side pooled buffers (readBuf and dataCache). Read must
// not be called afterwards.
//
// The underlying net.Conn must provide CloseRead() error, as *net.TCPConn
// and *net.UnixConn do. Otherwise a halfCloseUnsupportedError is returned
// and the SecureConn is left untouched; it does not panic on a failed type
// assertion.
func (c *SecureConn) CloseRead() error {
	hc, ok := c.Conn.(interface{ CloseRead() error })
	if !ok {
		return halfCloseUnsupportedError("CloseRead")
	}
	if c.readBuf != nil {
		readBufferPool.Put(c.readBuf)
	}
	if c.dataCache != nil {
		writeBufferPool.Put(c.dataCache)
	}
	c.readBuf, c.dataCache = nil, nil
	c.datalen = 0 // the cache it counted is gone
	return hc.CloseRead()
}

// CloseWrite shuts down the writing half of the underlying connection and
// returns the write-side pooled buffer (writeBuf). Write must not be called
// afterwards.
//
// The underlying net.Conn must provide CloseWrite() error, as *net.TCPConn,
// *net.UnixConn and *tls.Conn do. Otherwise a halfCloseUnsupportedError is
// returned and the SecureConn is left untouched.
func (c *SecureConn) CloseWrite() error {
	hc, ok := c.Conn.(interface{ CloseWrite() error })
	if !ok {
		return halfCloseUnsupportedError("CloseWrite")
	}
	if c.writeBuf != nil {
		writeBufferPool.Put(c.writeBuf)
	}
	c.writeBuf = nil
	return hc.CloseWrite()
}
// Read reads ciphertext from the underlying connection, decrypts it with the
// embedded cipher and copies the plaintext into b.
//
// On the first call (when DecryptorInited reports true), Read does the
// decryptor setup. It reads exactly InitBolckSize bytes (logged as the IV)
// and passes them to InitDecryptor.
//
// Decrypted bytes that do not fit into b are kept in dataCache. Later calls
// hand them out before reading from the network again.
//
// When Decrypt returns encrypt.ErrAgain, the cipher needs more input before
// it can emit plaintext (AEAD data arrives in blocks of a definite length).
// Read then reads again from the connection, at most as many bytes as the
// cipher asked for, and retries until decryption succeeds or fails.
//
// Following the io.Reader contract, n is always in [0, len(b)]: errors are
// reported with n == 0, never -1. If the connection reaches EOF while the
// cipher still waits for the rest of a block, io.ErrUnexpectedEOF is
// returned instead of retrying forever. An io.EOF that arrives together
// with data is held back while decrypted bytes remain cached, so those bytes
// are not lost. The connection reports EOF again on a later Read.
func (c *SecureConn) Read(b []byte) (n int, err error) {
	// First Read on this connection: consume the initial block and use it to
	// initialise the decryptor.
	if c.DecryptorInited() {
		iv := c.readBuf[:c.InitBolckSize()]
		if _, err := io.ReadFull(c.Conn, iv); err != nil {
			return 0, err
		}
		Logger.Debug("ss read iv", zap.Binary("iv", iv))
		if err := c.InitDecryptor(iv); err != nil {
			return 0, err
		}
	}

	// Serve plaintext left over from an earlier call before reading more.
	if c.datalen > 0 {
		n = copy(b, c.dataCache[:c.datalen])
		copy(c.dataCache, c.dataCache[n:c.datalen])
		c.datalen -= n
		return n, nil
	}

	nr, errR := c.Conn.Read(c.readBuf)
	if nr <= 0 {
		return 0, errR
	}

	nn, err := c.Cipher.Decrypt(c.readBuf[:nr], c.dataCache[c.datalen:])
	for err == encrypt.ErrAgain {
		// nn is how much more input the cipher wants; readBuf caps it.
		if nn > BufferSize {
			Logger.Warn("aead error again require data length is larger than expect! that should not happen", zap.Int("n", nn), zap.Int("buffer", BufferSize))
			nn = BufferSize
		}
		if nn <= 0 {
			// A zero-length Read typically returns (0, nil) at once, so
			// asking for 0 bytes would spin. Decrypt already accepts a full
			// readBuf on the first read above, so offer it one here too.
			nn = len(c.readBuf)
		}
		Logger.Debug("aead return errAgain, request more data", zap.Int("n", nn))
		more, errMore := c.Conn.Read(c.readBuf[:nn])
		if errMore != nil && errMore != io.EOF {
			return 0, errMore
		}
		if more == 0 && errMore == io.EOF {
			// The peer closed in the middle of a block; retrying would loop
			// forever on the same EOF.
			return 0, io.ErrUnexpectedEOF
		}
		nn, err = c.Cipher.Decrypt(c.readBuf[:more], c.dataCache[c.datalen:])
	}
	if err != nil {
		return 0, err
	}

	c.datalen += nn
	n = copy(b, c.dataCache[:c.datalen])
	copy(c.dataCache, c.dataCache[n:c.datalen])
	c.datalen -= n
	if errR == io.EOF && c.datalen > 0 {
		// Defer EOF until the cache is drained so callers such as io.Copy
		// do not stop with plaintext still buffered.
		return n, nil
	}
	return n, errR
}
// Write encrypts b with the embedded cipher and writes the ciphertext to the
// underlying connection.
//
// On the first call (when EncryptorInited reports true), the initial block
// returned by InitEncryptor (logged as the IV) is written before any data.
// An IV write shorter than InitBolckSize yields ErrUnexpectedIO.
//
// Following the io.Writer contract, a successful Write reports len(b), the
// plaintext consumed, and not the ciphertext length. The two differ for
// ciphers that add overhead, and io.Copy rejects a count larger than len(b).
// On error n is 0, never -1: once part of the ciphertext may have reached
// the wire, no partial plaintext count is meaningful.
func (c *SecureConn) Write(b []byte) (n int, err error) {
	if c.EncryptorInited() {
		iv, err := c.InitEncryptor()
		if err != nil {
			return 0, err
		}
		Logger.Debug("ss write iv", zap.Binary("iv", iv))
		nw, err := c.Conn.Write(iv)
		if err != nil {
			return 0, err
		}
		if nw != c.InitBolckSize() {
			return 0, ErrUnexpectedIO
		}
	}

	// FIXME: Encrypt writes into the fixed-size writeBuf
	// (BufferSize+writeBuffOffset bytes). Input whose ciphertext is larger
	// than that is not handled here; a bigger buffer would be needed.
	nc, err := c.Encrypt(b, c.writeBuf)
	if err != nil {
		return 0, err
	}

	// Write all nc ciphertext bytes, resuming after whatever each call
	// accepted.
	for start := 0; start < nc; {
		nw, err := c.Conn.Write(c.writeBuf[start:nc])
		if err != nil {
			return 0, err
		}
		if nw <= 0 {
			// io.Writer forbids a short write without an error; bail out
			// rather than loop forever on a misbehaving connection.
			return 0, io.ErrShortWrite
		}
		start += nw
	}
	return len(b), nil
}
// secureListener is a net.Listener whose Accept wraps each accepted
// connection in a SecureConn (built with NewSecureConn).
type secureListener struct {
	net.Listener
	// cipher is the template cipher. In this file it is not used directly;
	// Accept gives each connection its own ln.cipher.Copy().
	cipher encrypt.Cipher
	// timeout, in seconds, is forwarded to NewSecureConn for every accepted
	// connection.
	timeout int
}
// Accept waits for the next connection on the wrapped listener and returns
// it as a SecureConn. The SecureConn uses ln.cipher.Copy() (a separate cipher
// value per connection) and the listener's timeout. Errors from the wrapped
// listener are returned unchanged with a nil conn.
func (ln *secureListener) Accept() (conn net.Conn, err error) {
	raw, err := ln.Listener.Accept()
	if err != nil {
		return nil, err
	}
	// NewSecureConn has no error result, so there is nothing to clean up.
	return NewSecureConn(raw, ln.cipher.Copy(), ln.timeout), nil
}
// SecureListen announces on the local address laddr and returns a listener.
// Its Accept wraps every incoming connection in a SecureConn, using a copy of
// cipher and the given timeout in seconds (<= 0 means no deadline).
//
// network and laddr are passed to net.Listen unchanged. The intended
// networks are "tcp", "tcp4" and "tcp6". If laddr has a port of 0, a free
// port is chosen; the returned listener's Addr method reports it.
// A nil cipher yields ErrNilCipher without opening a socket.
func SecureListen(network, laddr string, cipher encrypt.Cipher, timeout int) (net.Listener, error) {
	if cipher == nil {
		return nil, ErrNilCipher
	}
	ln, err := net.Listen(network, laddr)
	if err != nil {
		return nil, err
	}
	// Keyed fields keep this literal correct if secureListener gains or
	// reorders fields.
	return &secureListener{Listener: ln, cipher: cipher, timeout: timeout}, nil
}