Skip to content

Commit ade3425

Browse files
committed
multi: fix recovering from bad connection.
1 parent 470271c commit ade3425

File tree

12 files changed

+108
-51
lines changed

12 files changed

+108
-51
lines changed

cluster_pipeline.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) {
6363
continue
6464
}
6565

66-
cn, err := client.conn()
66+
cn, _, err := client.conn()
6767
if err != nil {
6868
setCmdsErr(cmds, err)
6969
retErr = err

error.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,24 @@ func (err redisError) Error() string {
2626
}
2727

2828
func isNetworkError(err error) bool {
29-
if _, ok := err.(net.Error); ok || err == io.EOF {
29+
if err == io.EOF {
3030
return true
3131
}
32-
return false
32+
_, ok := err.(net.Error)
33+
return ok
34+
}
35+
36+
func isBadConn(cn *conn, ei error) bool {
37+
if cn.rd.Buffered() > 0 {
38+
return true
39+
}
40+
if ei == nil {
41+
return false
42+
}
43+
if _, ok := ei.(redisError); ok {
44+
return false
45+
}
46+
return true
3347
}
3448

3549
func isMovedError(err error) (moved bool, ask bool, addr string) {

multi.go

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ var errDiscard = errors.New("redis: Discard can be used only inside Exec")
1010

1111
// Multi implements Redis transactions as described in
1212
// http://redis.io/topics/transactions. It's NOT safe for concurrent use
13-
// by multiple goroutines, because Exec resets connection state.
13+
// by multiple goroutines, because Exec resets list of watched keys.
1414
// If you don't need WATCH it is better to use Pipeline.
1515
//
16-
// TODO(vmihailenco): rename to Tx
16+
// TODO(vmihailenco): rename to Tx and rework API
1717
type Multi struct {
1818
commandable
1919

@@ -34,6 +34,18 @@ func (c *Client) Multi() *Multi {
3434
return multi
3535
}
3636

37+
func (c *Multi) putConn(cn *conn, ei error) {
38+
var err error
39+
if isBadConn(cn, ei) {
40+
err = c.base.connPool.Remove(nil) // nil to force removal
41+
} else {
42+
err = c.base.connPool.Put(cn)
43+
}
44+
if err != nil {
45+
log.Printf("redis: putConn failed: %s", err)
46+
}
47+
}
48+
3749
func (c *Multi) process(cmd Cmder) {
3850
if c.cmds == nil {
3951
c.base.process(cmd)
@@ -112,15 +124,18 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) {
112124
return []Cmder{}, nil
113125
}
114126

115-
cn, err := c.base.conn()
127+
// Strip MULTI and EXEC commands.
128+
retCmds := cmds[1 : len(cmds)-1]
129+
130+
cn, _, err := c.base.conn()
116131
if err != nil {
117-
setCmdsErr(cmds[1:len(cmds)-1], err)
118-
return cmds[1 : len(cmds)-1], err
132+
setCmdsErr(retCmds, err)
133+
return retCmds, err
119134
}
120135

121136
err = c.execCmds(cn, cmds)
122-
c.base.putConn(cn, err)
123-
return cmds[1 : len(cmds)-1], err
137+
c.putConn(cn, err)
138+
return retCmds, err
124139
}
125140

126141
func (c *Multi) execCmds(cn *conn, cmds []Cmder) error {

multi_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,30 @@ var _ = Describe("Multi", func() {
119119
Expect(get.Val()).To(Equal("20000"))
120120
})
121121

122+
It("should recover from bad connection", func() {
123+
// Put bad connection in the pool.
124+
cn, _, err := client.Pool().Get()
125+
Expect(err).NotTo(HaveOccurred())
126+
127+
cn.SetNetConn(&badConn{})
128+
err = client.Pool().Put(cn)
129+
Expect(err).NotTo(HaveOccurred())
130+
131+
multi := client.Multi()
132+
defer func() {
133+
Expect(multi.Close()).NotTo(HaveOccurred())
134+
}()
135+
136+
_, err = multi.Exec(func() error {
137+
multi.Ping()
138+
return nil
139+
})
140+
Expect(err).To(MatchError("bad connection"))
141+
142+
_, err = multi.Exec(func() error {
143+
multi.Ping()
144+
return nil
145+
})
146+
Expect(err).NotTo(HaveOccurred())
147+
})
122148
})

pipeline.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) {
8888

8989
failedCmds := cmds
9090
for i := 0; i <= pipe.client.opt.MaxRetries; i++ {
91-
cn, err := pipe.client.conn()
91+
cn, _, err := pipe.client.conn()
9292
if err != nil {
9393
setCmdsErr(failedCmds, err)
9494
return cmds, err

pool.go

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ var (
1818

1919
type pool interface {
2020
First() *conn
21-
Get() (*conn, error)
21+
Get() (*conn, bool, error)
2222
Put(*conn) error
2323
Remove(*conn) error
2424
Len() int
@@ -212,33 +212,36 @@ func (p *connPool) new() (*conn, error) {
212212
}
213213

214214
// Get returns existed connection from the pool or creates a new one.
215-
func (p *connPool) Get() (*conn, error) {
215+
func (p *connPool) Get() (cn *conn, isNew bool, err error) {
216216
if p.closed() {
217-
return nil, errClosed
217+
err = errClosed
218+
return
218219
}
219220

220221
// Fetch first non-idle connection, if available.
221-
if cn := p.First(); cn != nil {
222-
return cn, nil
222+
if cn = p.First(); cn != nil {
223+
return
223224
}
224225

225226
// Try to create a new one.
226227
if p.conns.Reserve() {
227-
cn, err := p.new()
228+
cn, err = p.new()
228229
if err != nil {
229230
p.conns.Remove(nil)
230-
return nil, err
231+
return
231232
}
232233
p.conns.Add(cn)
233-
return cn, nil
234+
isNew = true
235+
return
234236
}
235237

236238
// Otherwise, wait for the available connection.
237-
if cn := p.wait(); cn != nil {
238-
return cn, nil
239+
if cn = p.wait(); cn != nil {
240+
return
239241
}
240242

241-
return nil, errPoolTimeout
243+
err = errPoolTimeout
244+
return
242245
}
243246

244247
func (p *connPool) Put(cn *conn) error {
@@ -327,8 +330,8 @@ func (p *singleConnPool) First() *conn {
327330
return p.cn
328331
}
329332

330-
func (p *singleConnPool) Get() (*conn, error) {
331-
return p.cn, nil
333+
func (p *singleConnPool) Get() (*conn, bool, error) {
334+
return p.cn, false, nil
332335
}
333336

334337
func (p *singleConnPool) Put(cn *conn) error {
@@ -382,24 +385,25 @@ func (p *stickyConnPool) First() *conn {
382385
return cn
383386
}
384387

385-
func (p *stickyConnPool) Get() (*conn, error) {
388+
func (p *stickyConnPool) Get() (cn *conn, isNew bool, err error) {
386389
defer p.mx.Unlock()
387390
p.mx.Lock()
388391

389392
if p.closed {
390-
return nil, errClosed
393+
err = errClosed
394+
return
391395
}
392396
if p.cn != nil {
393-
return p.cn, nil
397+
cn = p.cn
398+
return
394399
}
395400

396-
cn, err := p.pool.Get()
401+
cn, isNew, err = p.pool.Get()
397402
if err != nil {
398-
return nil, err
403+
return
399404
}
400405
p.cn = cn
401-
402-
return p.cn, nil
406+
return
403407
}
404408

405409
func (p *stickyConnPool) put() (err error) {

pool_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ var _ = Describe("pool", func() {
107107
})
108108

109109
It("should remove broken connections", func() {
110-
cn, err := client.Pool().Get()
110+
cn, _, err := client.Pool().Get()
111111
Expect(err).NotTo(HaveOccurred())
112112
Expect(cn.Close()).NotTo(HaveOccurred())
113113
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())
@@ -141,12 +141,12 @@ var _ = Describe("pool", func() {
141141
pool := client.Pool()
142142

143143
// Reserve one connection.
144-
cn, err := client.Pool().Get()
144+
cn, _, err := client.Pool().Get()
145145
Expect(err).NotTo(HaveOccurred())
146146

147147
// Reserve the rest of connections.
148148
for i := 0; i < 9; i++ {
149-
_, err := client.Pool().Get()
149+
_, _, err := client.Pool().Get()
150150
Expect(err).NotTo(HaveOccurred())
151151
}
152152

@@ -191,7 +191,7 @@ func BenchmarkPool(b *testing.B) {
191191

192192
b.RunParallel(func(pb *testing.PB) {
193193
for pb.Next() {
194-
conn, err := pool.Get()
194+
conn, _, err := pool.Get()
195195
if err != nil {
196196
b.Fatalf("no error expected on pool.Get but received: %s", err.Error())
197197
}

pubsub.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) {
4747
}
4848

4949
func (c *PubSub) subscribe(cmd string, channels ...string) error {
50-
cn, err := c.conn()
50+
cn, _, err := c.conn()
5151
if err != nil {
5252
return err
5353
}
@@ -112,7 +112,7 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error {
112112
}
113113

114114
func (c *PubSub) Ping(payload string) error {
115-
cn, err := c.conn()
115+
cn, _, err := c.conn()
116116
if err != nil {
117117
return err
118118
}
@@ -208,14 +208,16 @@ func newMessage(reply []interface{}) (interface{}, error) {
208208
// is not received in time. This is low-level API and most clients
209209
// should use ReceiveMessage.
210210
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
211-
cn, err := c.conn()
211+
cn, _, err := c.conn()
212212
if err != nil {
213213
return nil, err
214214
}
215215
cn.ReadTimeout = timeout
216216

217217
cmd := NewSliceCmd()
218-
if err := cmd.readReply(cn); err != nil {
218+
err = cmd.readReply(cn)
219+
c.putConn(cn, err)
220+
if err != nil {
219221
return nil, err
220222
}
221223
return newMessage(cmd.Val())
@@ -229,7 +231,7 @@ func (c *PubSub) Receive() (interface{}, error) {
229231
}
230232

231233
func (c *PubSub) reconnect() {
232-
c.connPool.Remove(nil) // close current connection
234+
c.connPool.Remove(nil) // nil to force removal
233235
if len(c.channels) > 0 {
234236
if err := c.Subscribe(c.channels...); err != nil {
235237
log.Printf("redis: Subscribe failed: %s", err)

pubsub_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ var _ = Describe("PubSub", func() {
254254
Expect(err).NotTo(HaveOccurred())
255255
defer pubsub.Close()
256256

257-
cn, err := pubsub.Pool().Get()
257+
cn, _, err := pubsub.Pool().Get()
258258
Expect(err).NotTo(HaveOccurred())
259259
cn.SetNetConn(&badConn{
260260
readErr: errTimeout,

redis.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,16 @@ func (c *baseClient) String() string {
1616
return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB)
1717
}
1818

19-
func (c *baseClient) conn() (*conn, error) {
19+
func (c *baseClient) conn() (*conn, bool, error) {
2020
return c.connPool.Get()
2121
}
2222

2323
func (c *baseClient) putConn(cn *conn, ei error) {
2424
var err error
25-
if cn.rd.Buffered() > 0 {
25+
if isBadConn(cn, ei) {
2626
err = c.connPool.Remove(cn)
27-
} else if ei == nil {
28-
err = c.connPool.Put(cn)
29-
} else if _, ok := ei.(redisError); ok {
30-
err = c.connPool.Put(cn)
3127
} else {
32-
err = c.connPool.Remove(cn)
28+
err = c.connPool.Put(cn)
3329
}
3430
if err != nil {
3531
log.Printf("redis: putConn failed: %s", err)
@@ -42,7 +38,7 @@ func (c *baseClient) process(cmd Cmder) {
4238
cmd.reset()
4339
}
4440

45-
cn, err := c.conn()
41+
cn, _, err := c.conn()
4642
if err != nil {
4743
cmd.setErr(err)
4844
return

redis_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ var _ = Describe("Client", func() {
157157
})
158158

159159
// Put bad connection in the pool.
160-
cn, err := client.Pool().Get()
160+
cn, _, err := client.Pool().Get()
161161
Expect(err).NotTo(HaveOccurred())
162162

163163
cn.SetNetConn(&badConn{})

ring.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) {
313313

314314
for name, cmds := range cmdsMap {
315315
client := pipe.ring.shards[name].Client
316-
cn, err := client.conn()
316+
cn, _, err := client.conn()
317317
if err != nil {
318318
setCmdsErr(cmds, err)
319319
if retErr == nil {

0 commit comments

Comments
 (0)