-
Notifications
You must be signed in to change notification settings - Fork 0
/
conn.go
575 lines (534 loc) · 13.9 KB
/
conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
package red
import (
"crypto/sha1"
"encoding/hex"
"errors"
"fmt"
"net"
"strconv"
"strings"
"time"
"github.com/alxarch/red/internal/pipeline"
"github.com/alxarch/red/resp"
)
// Conn is a redis client connection.
//
// Commands are buffered in a pipeline writer and the replies they are
// expected to produce are tracked in state, so that Scan can match each
// reply read from the stream to the command that produced it.
type Conn struct {
	// noCopy presumably lets `go vet` (copylocks) flag copies of Conn;
	// the noCopy type is declared elsewhere — confirm.
	noCopy noCopy //nolint:unused,structcheck
	// conn is the underlying network connection. Close sets it to nil and
	// Err reports the connection as closed once it is nil.
	conn net.Conn
	// w buffers written commands until flush.
	w PipelineWriter
	// r decodes replies read from the connection.
	r resp.Stream
	// options holds the active options; Reset may replace them.
	options ConnOptions
	// managed is true while the connection is handed out as a managedConn;
	// WriteCommand, Scan, ScanMulti and Reset then return errConnManaged.
	managed bool
	// state tracks pending replies and the connection mode (MULTI/EXEC,
	// WATCH, CLIENT REPLY mode, selected DB, blocking timeouts).
	state   pipeline.State
	scripts map[Arg]string // Loaded scripts: EVAL script argument -> SHA1 hex digest used for EVALSHA

	// Pool fields.
	// createdAt and lastUsedAt are not used in this section of the file;
	// presumably maintained by the pool — confirm.
	createdAt  time.Time
	lastUsedAt time.Time
	// pool, when non-nil, receives the connection back on Close.
	pool *Pool
}
// WriteCommand writes a redis command to the pipeline buffer updating the state.
//
// The command name is case-insensitive. It is upper-cased before any check,
// so the CLIENT/SUBSCRIBE guards and the pipeline state tracking see the same
// name whether or not options.Debug is set. Unless Debug is set, EVAL
// commands may be rewritten to EVALSHA (see rewriteCommand).
//
// CLIENT and (P)SUBSCRIBE/(P)UNSUBSCRIBE commands are rejected. A write error
// closes the connection and is returned.
func (conn *Conn) WriteCommand(name string, args ...Arg) error {
	if conn.managed {
		return errConnManaged
	}
	// Normalize unconditionally: previously only rewriteCommand upper-cased
	// the name, so in Debug mode "client" or "multi" slipped past the guards
	// below and were mis-tracked by updatePipeline.
	name = strings.ToUpper(name)
	if !conn.options.Debug {
		name, args = conn.rewriteCommand(name, args)
	}
	switch name {
	case "CLIENT":
		return errors.New("CLIENT commands not allowed")
	case "SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE":
		return errors.New("Subscribe commands not allowed")
	}
	if err := conn.w.WriteCommand(name, args...); err != nil {
		_ = conn.Close()
		return err
	}
	conn.updatePipeline(name, args...)
	return nil
}
// DoCommand executes a single redis command and decodes its reply into dest.
//
// It refuses to run while earlier pipelined replies are still pending,
// because the next reply read would then belong to another command.
func (conn *Conn) DoCommand(dest interface{}, name string, args ...Arg) error {
	err := conn.Err()
	if err != nil {
		return err
	}
	if pending := conn.state.CountReplies(); pending > 0 {
		return fmt.Errorf("Pending %d replies", pending)
	}
	if err = conn.WriteCommand(name, args...); err != nil {
		return err
	}
	return conn.Scan(dest)
}
// func (conn *Conn) clientScanValue(skip bool) (pipeline.Entry, resp.Value, error) {
// if err := conn.flush(); err != nil {
// return pipeline.Entry{}, resp.Value{}, err
// }
// for {
// entry, ok := conn.state.Pop()
// if !ok {
// return pipeline.Entry{}, resp.Value{}, ErrNoReplies
// }
// if entry.Skip() {
// continue
// }
// var v resp.Value
// var err error
// if skip {
// err = conn.discardValue(entry)
// } else {
// v, err = conn.readValue(entry)
// }
// if err != nil {
// conn.closeWithError(err)
// }
// return entry, v, err
// }
// }
// replyExec decodes the reply to EXEC into the per-command destinations
// passed to ScanMulti.
type replyExec struct {
	dest []interface{}
	err  error
}

// UnmarshalRESP decodes an EXEC reply.
//
// A null array means the transaction was not executed because a WATCHed key
// changed. A non-array error reply means the transaction was aborted. For an
// array reply, element i is decoded into dest[i] (nil entries are skipped)
// and the reply must have exactly len(dest) elements.
//
// NOTE: dest doubles as a result slot: every non-nil dest[i] is overwritten
// with nil on success or with its decode error on failure, and the first
// such error (in order) is returned.
func (r *replyExec) UnmarshalRESP(v resp.Value) error {
	if v.NullArray() {
		return fmt.Errorf("MULTI/EXEC transaction WATCH failed")
	}
	if v.Type() != resp.TypeArray {
		if v.Err() != nil {
			return fmt.Errorf("MULTI/EXEC transaction aborted %s", v.Err())
		}
		return fmt.Errorf("Invalid exec reply %v", v.Any())
	}

	iter := v.Iter()
	for i := range r.dest {
		if !iter.More() {
			return fmt.Errorf("Invalid multi size %d > %d", len(r.dest), v.Len())
		}
		if target := r.dest[i]; target != nil {
			decodeErr := iter.Value().Decode(target)
			if decodeErr == nil {
				r.dest[i] = nil
			} else {
				r.dest[i] = decodeErr
			}
		}
		iter.Next()
	}
	if iter.More() {
		return fmt.Errorf("Invalid target size %d < %d", len(r.dest), v.Len())
	}

	for i := range r.dest {
		if failed, isErr := r.dest[i].(error); isErr && failed != nil {
			return failed
		}
	}
	return nil
}
// ScanMulti scans the results of a MULTI/EXEC transaction.
//
// The next pending reply must belong to a MULTI command. The MULTI and
// QUEUED replies are checked and discarded, and the EXEC reply is decoded
// element by element into dest (see replyExec). If the transaction ended
// with DISCARD, its reply is consumed and an error is returned.
func (conn *Conn) ScanMulti(dest ...interface{}) error {
	if err := conn.Err(); err != nil {
		return err
	}
	if conn.managed {
		return errConnManaged
	}
	if err := conn.flush(); err != nil {
		return err
	}
	if conn.options.WriteOnly {
		return errConnWriteOnly
	}
	if entry := conn.state.Peek(); !entry.Multi() {
		return fmt.Errorf("Non multi entry ahead %v", entry)
	}
	for {
		entry, ok := conn.state.Pop()
		if !ok {
			return ErrNoReplies
		}
		switch {
		case entry.Skip():
			continue
		case entry.Multi():
			var isOK AssertOK
			if err := conn.scanValue(&isOK, entry); err != nil {
				return fmt.Errorf("MULTI failed: %s", err)
			}
		case entry.Discard():
			// The DISCARD entry has been popped from the state, so its reply
			// must be read too; otherwise the next Scan would receive it and
			// every later reply would be matched to the wrong command.
			if err := conn.scanValue(nil, entry); err != nil {
				return err
			}
			return fmt.Errorf("MULTI/EXEC transaction discarded")
		case entry.Exec():
			exec := replyExec{
				dest: dest,
			}
			return conn.scanValue(&exec, entry)
		case entry.Queued():
			if err := conn.scanValue(nil, entry); err != nil {
				return err
			}
		default:
			return fmt.Errorf("Invalid MULTI/EXEC entry %v", entry)
		}
	}
}
// Scan flushes the pipeline and decodes the next pending reply into dest.
//
// Entries whose reply was suppressed (CLIENT REPLY SKIP) are passed over.
// Scan returns ErrNoReplies when nothing is pending, errConnWriteOnly for
// write-only connections and errConnManaged while the connection is managed.
//
// Design notes on blocking commands and timeouts:
//   - A blocking command's timeout runs on the server. Because commands are
//     pipelined, the client cannot know exactly when the server starts that
//     timer. The read deadline is therefore computed when the reply is read
//     (see resetTimeout). It may be lax, which is harmless on a healthy
//     connection. Recording the deadline in the pipeline entry at write time
//     would be more precise.
//   - Testing with redis-cli showed the server does not honor the timeout of
//     a blocking command queued inside MULTI/EXEC: it pops immediately or
//     replies nil. The blocking timeout of queued commands is therefore
//     ignored.
//   - CLIENT REPLY SKIP and blocking commands don't mix well. If the command
//     is wrong nothing is returned; otherwise the SKIP is ignored and does not
//     carry over to the next command. For this reason CLIENT REPLY cannot be
//     sent through WriteCommand and is only used internally.
func (conn *Conn) Scan(dest interface{}) error {
	if err := conn.Err(); err != nil {
		return err
	}
	switch {
	case conn.managed:
		return errConnManaged
	case conn.options.WriteOnly:
		return errConnWriteOnly
	}
	if err := conn.flush(); err != nil {
		return err
	}
	for {
		entry, ok := conn.state.Pop()
		switch {
		case !ok:
			return ErrNoReplies
		case entry.Skip():
			// No reply is read for skipped entries.
			continue
		}
		return conn.scanValue(dest, entry)
	}
}
// WriteQuick is a convenience wrapper for WriteCommand: it writes command
// name with key followed by plain string arguments, packed by QuickArgs.
func (conn *Conn) WriteQuick(name, key string, args ...string) error {
	packed := QuickArgs(key, args...)
	return conn.WriteCommand(name, packed...)
}
// ConnOptions holds connection options.
//
// Reset applies a ConnOptions value to an existing connection.
type ConnOptions struct {
	ReadBufferSize  int           // Size of the read buffer
	WriteBufferSize int           // Size of the write buffer
	ReadTimeout     time.Duration // If > 0 all reads will fail if exceeded; a blocking command adds its own timeout on top
	WriteTimeout    time.Duration // If > 0 all writes will fail if exceeded
	WriteOnly       bool          // WriteOnly connections return no replies (CLIENT REPLY OFF); Scan/ScanMulti return an error
	DB              int           // Redis DB index; Reset issues SELECT when it differs from the current DB
	KeyPrefix       string        // Prefix all keys
	Auth            string        // Redis auth
	Debug           bool          // Disables script injection (EVAL is sent as-is instead of being rewritten to EVALSHA)
}
// ErrNoReplies is returned by Scan and ScanMulti when the pipeline has no
// more replies pending.
var ErrNoReplies = errors.New("No more replies")

// Internal connection errors.
var (
	// errConnClosed reports a closed (or nil) connection.
	errConnClosed = errors.New("Connection closed")
	// errConnManaged reports use of a connection currently handed out as a managedConn.
	errConnManaged = errors.New("Connection managed by client")
	// errConnWriteOnly reports an attempt to read replies on a write-only connection.
	errConnWriteOnly = errors.New("Connection write only")
)
// // Managed checks if a connection is managed by a client
// func (conn *Conn) Managed() bool {
// return conn.managed
// }
// Dirty reports whether the connection has pending replies that have not
// been scanned yet.
func (conn *Conn) Dirty() bool {
	state := &conn.state
	return state.Dirty()
}
// Err reports whether the connection is usable. It returns errConnClosed
// for a nil Conn or one whose network connection has been closed, and nil
// otherwise.
func (conn *Conn) Err() error {
	if conn == nil || conn.conn == nil {
		return errConnClosed
	}
	return nil
}
// Close closes a redis connection.
//
// A pooled connection (pool != nil) is handed back to its pool instead of
// being closed. Otherwise the underlying network connection is closed and
// cleared, so Err reports errConnClosed from then on. Closing a nil or
// already-closed connection returns errConnClosed.
func (conn *Conn) Close() error {
	if conn == nil {
		return errConnClosed
	}
	if conn.pool != nil {
		return conn.pool.put(conn)
	}
	// The original code tested conn instead of cn here, so a second Close
	// invoked Close on a nil net.Conn and panicked. Internal error paths
	// (e.g. scanValue followed by drain) do call Close more than once.
	cn := conn.conn
	if cn == nil {
		return errConnClosed
	}
	conn.conn = nil
	return cn.Close()
}
// Reset resets the connection to a state as defined by the options.
//
// If options is nil the current options are reapplied; otherwise they
// replace the current options. Reset performs these steps:
//   - It aborts a pending MULTI (DISCARD) or WATCH (UNWATCH).
//   - It sets the CLIENT REPLY mode required by WriteOnly.
//   - It selects options.DB when valid and different from the current DB.
//   - It flushes and drains pending replies (see clear).
//
// It returns the connection error, if any, afterwards.
func (conn *Conn) Reset(options *ConnOptions) error {
	if err := conn.Err(); err != nil {
		return err
	}
	if conn.managed {
		return errConnManaged
	}
	if options == nil {
		options = &conn.options
	} else {
		conn.options = *options
	}

	// WriteCommand rejects every CLIENT command, so the original calls
	// conn.WriteCommand("CLIENT", "REPLY", ...) were silent no-ops. Write the
	// mode straight to the pipeline writer instead, as injectCommand does for
	// SKIP, and record it in the pipeline state.
	clientReply := func(mode string) {
		if err := conn.w.WriteCommand("CLIENT", String("REPLY"), String(mode)); err != nil {
			_ = conn.Close()
			return
		}
		conn.updatePipeline("CLIENT", String("REPLY"), String(mode))
	}

	state := &conn.state
	switch {
	case state.IsMulti():
		_ = conn.WriteCommand("DISCARD")
	case state.IsWatch():
		_ = conn.WriteCommand("UNWATCH")
	}
	switch {
	case options.WriteOnly:
		clientReply("OFF")
	case state.IsReplyOFF():
		clientReply("ON")
	case state.IsReplySkip():
		// Give the pending SKIP a command to consume.
		_ = conn.WriteCommand("PING")
	}
	if DBIndexValid(options.DB) && int(state.DB()) != options.DB {
		_ = conn.injectCommand("SELECT", Int(options.DB))
	}
	return conn.clear()
}
// clear leaves the connection with no replies to read and reports the
// resulting connection error, if any.
//
// On a write-only connection it makes sure the server stops sending replies
// (CLIENT REPLY OFF) unless the pipeline is already in that mode. Otherwise
// it flushes the pipeline and drains every pending reply. Failures close the
// connection (inside flush/drain) and surface through Err.
func (conn *Conn) clear() error {
	if conn.options.WriteOnly {
		if !conn.state.IsReplyOFF() {
			// WriteCommand rejects CLIENT commands, so write the mode to the
			// pipeline writer directly and record it in the state.
			if err := conn.w.WriteCommand("CLIENT", String("REPLY"), String("OFF")); err != nil {
				_ = conn.Close()
			} else {
				conn.updatePipeline("CLIENT", String("REPLY"), String("OFF"))
			}
		}
		return conn.Err()
	}
	// Draining after a failed flush would read replies for commands that
	// were never sent, so only drain when the flush succeeded. drain always
	// finishes with ErrNoReplies, which is expected here.
	if err := conn.flush(); err == nil {
		_ = conn.drain()
	}
	return conn.Err()
}
// rewriteCommand upper-cases the command name. For EVAL it replaces the
// script body with its SHA1 digest and switches to EVALSHA. The first time a
// script is seen on this connection, a SCRIPT LOAD is injected ahead of the
// command with its reply suppressed (see injectCommand).
//
// The caller's args slice is never modified; a copy is returned when the
// first argument is replaced. If the SCRIPT LOAD cannot be injected (e.g.
// inside MULTI/EXEC), the command stays a plain EVAL and the digest is not
// cached, so the server is never sent an EVALSHA for a script it may not know.
func (conn *Conn) rewriteCommand(name string, args []Arg) (string, []Arg) {
	name = strings.ToUpper(name)
	if name != "EVAL" || len(args) == 0 {
		return name, args
	}
	// evalSHA builds the EVALSHA form on a copy of args. Rewriting args[0] in
	// place would corrupt a reused caller slice: the next EVAL would hash the
	// digest itself as if it were a script.
	evalSHA := func(digest string) (string, []Arg) {
		rewritten := make([]Arg, len(args))
		copy(rewritten, args)
		rewritten[0] = String(digest)
		return "EVALSHA", rewritten
	}

	script := args[0]
	if digest, ok := conn.scripts[script]; ok {
		return evalSHA(digest)
	}
	body, ok := script.Value().(string)
	if !ok {
		return name, args
	}
	if err := conn.injectCommand("SCRIPT", String("LOAD"), String(body)); err != nil {
		return name, args
	}
	digest := sha1Sum(body)
	if conn.scripts == nil {
		conn.scripts = make(map[Arg]string)
	}
	conn.scripts[script] = digest
	return evalSHA(digest)
}
// injectCommand writes an internal command (such as SCRIPT LOAD or SELECT)
// whose reply the caller will not scan.
//
// When replies are already off (CLIENT REPLY OFF or a write-only
// connection), the command is written as-is. Otherwise it is preceded by
// CLIENT REPLY SKIP so the server sends no reply for it. Injection is
// refused inside MULTI/EXEC and while a SKIP is already pending.
func (conn *Conn) injectCommand(name string, args ...Arg) error {
	state := &conn.state
	if state.IsMulti() {
		return fmt.Errorf("Connection is in MULTI/EXEC transaction")
	}
	if state.IsReplyOFF() || conn.options.WriteOnly {
		return conn.WriteCommand(name, args...)
	}
	if state.IsReplySkip() {
		return fmt.Errorf("Connection is already on CLIENT REPLY SKIP")
	}
	// WriteCommand rejects CLIENT commands, so SKIP goes straight to the
	// pipeline writer. A write error there is sticky in the writer and is
	// reported by the WriteCommand call below.
	_ = conn.w.WriteCommand("CLIENT", String("REPLY"), String("SKIP"))
	conn.updatePipeline("CLIENT", String("REPLY"), String("SKIP"))
	return conn.WriteCommand(name, args...)
}
// flush sends the buffered pipeline to the server. A flush error closes the
// connection and is returned.
func (conn *Conn) flush() error {
	err := conn.w.Flush()
	if err != nil {
		_ = conn.Close()
	}
	return err
}
// drain reads and discards every pending reply. Skipped entries are passed
// over. It returns ErrNoReplies once the pipeline is empty. On a read error
// it closes the connection and returns that error.
func (conn *Conn) drain() error {
	for entry, ok := conn.state.Pop(); ok; entry, ok = conn.state.Pop() {
		if entry.Skip() {
			continue
		}
		if err := conn.scanValue(nil, entry); err != nil {
			_ = conn.Close()
			return err
		}
	}
	return ErrNoReplies
}
// sha1Sum returns the lower-case hex SHA1 digest of s, the form used as the
// EVALSHA script identifier.
func sha1Sum(s string) string {
	digest := sha1.Sum([]byte(s))
	return hex.EncodeToString(digest[:])
}
// resetTimeout sets the read deadline for reading the reply to entry.
//
// The base timeout is options.ReadTimeout (negative values count as 0).
// For a blocking command that is not queued inside MULTI/EXEC, the command's
// own timeout is added on top. A blocking timeout <= 0 means "block forever"
// and removes the deadline. Whenever the resulting timeout is not positive,
// the deadline is cleared explicitly. Without that, a deadline left by an
// earlier blocking read would make later reads time out spuriously.
func (conn *Conn) resetTimeout(entry pipeline.Entry) error {
	nc := conn.conn
	if nc == nil {
		return errConnClosed
	}
	timeout := conn.options.ReadTimeout
	if timeout < 0 {
		timeout = 0
	}
	// The server ignores blocking timeouts inside MULTI/EXEC, so only
	// blocking commands that are not queued get the extra time.
	if t, block := entry.Block(); block && !entry.Queued() {
		if t <= 0 {
			return nc.SetReadDeadline(time.Time{})
		}
		timeout += t
	}
	if timeout > 0 {
		return nc.SetReadDeadline(time.Now().Add(timeout))
	}
	return nc.SetReadDeadline(time.Time{})
}
// isDecodeError reports whether err is a *resp.DecodeError, i.e. a reply was
// received but could not be decoded into the destination.
func isDecodeError(err error) bool {
	switch err.(type) {
	case *resp.DecodeError:
		return true
	default:
		return false
	}
}
// scanValue sets the read deadline for entry and decodes the next reply into
// dest (nil discards it).
//
// A failure to set the deadline closes the connection. A decode failure
// closes it too, unless the error is a *resp.DecodeError. Presumably a
// DecodeError means the reply was fully read but did not fit dest, so the
// stream is still usable — confirm against resp.Stream.
func (conn *Conn) scanValue(dest interface{}, entry pipeline.Entry) error {
	err := conn.resetTimeout(entry)
	if err != nil {
		_ = conn.Close()
		return err
	}
	err = conn.r.Decode(dest)
	if err != nil && !isDecodeError(err) {
		_ = conn.Close()
	}
	return err
}
// func (conn *Conn) manage() {
// conn.managed = true
// }
// func (conn *Conn) unmanage() {
// conn.managed = false
// }
// func (conn *Conn) getClient() *Client {
// if conn.pool != nil {
// return conn.pool.getClient()
// }
// return new(Client)
// }
// func (conn *Conn) putClient(client *Client) {
// if conn.pool != nil {
// conn.pool.putClient(client)
// }
// }
// Auth authenticates the connection by sending AUTH with password. Any
// failure is reported wrapped as "Authentication failed: ...".
func (conn *Conn) Auth(password string) error {
	var reply AssertOK
	err := conn.DoCommand(&reply, "AUTH", String(password))
	if err == nil {
		return nil
	}
	return fmt.Errorf("Authentication failed: %s", err)
}
// updatePipeline records a written command in the pipeline state, so that
// expected replies and the connection mode can be tracked. The tracked mode
// covers MULTI/EXEC, WATCH, CLIENT REPLY mode, selected DB and blocking
// timeout. name is expected to be upper-case.
func (conn *Conn) updatePipeline(name string, args ...Arg) {
	state := &conn.state
	switch name {
	case "MULTI":
		state.Multi()
	case "EXEC":
		state.Exec()
	case "DISCARD":
		state.Discard()
	case "WATCH":
		state.Watch(len(args))
	case "UNWATCH":
		state.Unwatch()
	case "SELECT":
		// Only a recognizable, in-range index changes the tracked DB.
		if index := selectArg(args); 0 <= index && index < MaxDBIndex {
			state.Select(index)
		} else {
			state.Command()
		}
	case "CLIENT":
		switch clientReplyArg(args) {
		case "ON":
			state.ReplyON()
		case "OFF":
			state.ReplyOFF()
		case "SKIP":
			state.ReplySkip()
		default:
			state.Command()
		}
	case "BLPOP", "BRPOP", "BRPOPLPUSH", "BZPOPMIN", "BZPOPMAX":
		state.Block(lastArgTimeout(args))
	default:
		state.Command()
	}
}
// selectArg returns the DB index argument of a SELECT command, or -1 when it
// is missing or not an integer.
//
// Both integer arguments and decimal string arguments (e.g. String("1"))
// are recognized, since either form reaches the server as the same index.
// Otherwise the tracked DB would drift from the real one, and Reset could
// return a connection still on the wrong DB.
func selectArg(args []Arg) int64 {
	if len(args) == 0 {
		return -1
	}
	switch v := args[0].Value().(type) {
	case int64:
		return v
	case string:
		// Accept only the canonical decimal form. Redis parses integers
		// strictly (presumably rejecting e.g. "+1" or "01"), and tracking a DB
		// the server refused would be wrong.
		index, err := strconv.ParseInt(v, 10, 64)
		if err == nil && strconv.FormatInt(index, 10) == v {
			return index
		}
	}
	return -1
}
// clientReplyArg returns the upper-cased mode of a "CLIENT REPLY <mode>"
// command, or "" when args are not exactly a REPLY subcommand with a string
// mode.
func clientReplyArg(args []Arg) string {
	if len(args) != 2 {
		return ""
	}
	sub, isString := args[0].Value().(string)
	if !isString || strings.ToUpper(sub) != "REPLY" {
		return ""
	}
	mode, isString := args[1].Value().(string)
	if !isString {
		return ""
	}
	return strings.ToUpper(mode)
}
// lastArgTimeout extracts the timeout of a blocking command (BLPOP, BRPOP,
// BRPOPLPUSH, BZPOPMIN, BZPOPMAX) from its last argument.
//
// Redis interprets this timeout in seconds, and Redis 6.0+ also accepts
// fractional values. The old code read it as milliseconds, so the read
// deadline expired long before the server replied. Commands with fewer than
// two arguments, non-numeric, non-positive, NaN or overflowing values yield
// 0, which resetTimeout treats as "no deadline".
func lastArgTimeout(args []Arg) time.Duration {
	if len(args) < 2 {
		return 0
	}
	var seconds float64
	switch v := args[len(args)-1].Value().(type) {
	case int64:
		seconds = float64(v)
	case float64:
		seconds = v
	case string:
		f, err := strconv.ParseFloat(v, 64)
		if err != nil {
			return 0
		}
		seconds = f
	default:
		return 0
	}
	// Largest whole number of seconds that still fits in a time.Duration.
	const maxSeconds = float64((1<<63 - 1) / time.Second)
	if !(seconds > 0) || seconds >= maxSeconds {
		return 0
	}
	return time.Duration(seconds * float64(time.Second))
}
// managedConn is a handle to a Conn that is handed out while the Conn is
// marked as managed. In that state, the Conn's WriteCommand, Scan, ScanMulti
// and Reset return errConnManaged.
type managedConn struct {
	*Conn
}

// Close releases the handle. It detaches the Conn and clears its managed
// flag, so the Conn can be used directly again; the network connection is
// left open. Closing an already released handle returns errConnClosed.
func (m *managedConn) Close() error {
	// The original tested m.conn (the embedded Conn's net.Conn field), which
	// panicked with a nil pointer dereference once m.Conn was nil. It also
	// never released a Conn whose network connection had been closed.
	conn := m.Conn
	if conn == nil {
		return errConnClosed
	}
	m.Conn = nil
	conn.managed = false
	return nil
}