Skip to content

Commit 110e93a

Browse files
committed
Simplify connection management with sticky connection pool. Fixes #260.
1 parent 0382d1e commit 110e93a

10 files changed

+139
-89
lines changed

error.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,14 @@ func isNetworkError(err error) bool {
3333
return ok
3434
}
3535

36-
func isBadConn(cn *conn, ei error) bool {
37-
if cn.rd.Buffered() > 0 {
38-
return true
36+
func isBadConn(err error) bool {
37+
if err == nil {
38+
return false
3939
}
40-
if ei == nil {
40+
if _, ok := err.(redisError); ok {
4141
return false
4242
}
43-
if _, ok := ei.(redisError); ok {
43+
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
4444
return false
4545
}
4646
return true

export_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ func (c *baseClient) Pool() pool {
66
return c.connPool
77
}
88

9+
func (c *PubSub) Pool() pool {
10+
return c.base.connPool
11+
}
12+
913
var NewConnDialer = newConnDialer
1014

1115
func (cn *conn) SetNetConn(netcn net.Conn) {

main_test.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"os/exec"
88
"path/filepath"
99
"sync/atomic"
10-
"syscall"
1110
"testing"
1211
"time"
1312

@@ -243,10 +242,6 @@ func startSentinel(port, masterName, masterPort string) (*redisProcess, error) {
243242

244243
//------------------------------------------------------------------------------
245244

246-
var (
247-
errTimeout = syscall.ETIMEDOUT
248-
)
249-
250245
type badConnError string
251246

252247
func (e badConnError) Error() string { return string(e) }

multi.go

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,6 @@ func (c *Client) Multi() *Multi {
4545
return multi
4646
}
4747

48-
func (c *Multi) putConn(cn *conn, err error) {
49-
if isBadConn(cn, err) {
50-
// Close current connection.
51-
c.base.connPool.(*stickyConnPool).Reset(err)
52-
} else {
53-
err := c.base.connPool.Put(cn)
54-
if err != nil {
55-
Logger.Printf("pool.Put failed: %s", err)
56-
}
57-
}
58-
}
59-
6048
func (c *Multi) process(cmd Cmder) {
6149
if c.cmds == nil {
6250
c.base.process(cmd)
@@ -145,7 +133,7 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) {
145133
}
146134

147135
err = c.execCmds(cn, cmds)
148-
c.putConn(cn, err)
136+
c.base.putConn(cn, err)
149137
return retCmds, err
150138
}
151139

multi_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,31 @@ var _ = Describe("Multi", func() {
166166
})
167167
Expect(err).NotTo(HaveOccurred())
168168
})
169+
170+
It("should recover from bad connection when there are no commands", func() {
171+
// Put bad connection in the pool.
172+
cn, _, err := client.Pool().Get()
173+
Expect(err).NotTo(HaveOccurred())
174+
175+
cn.SetNetConn(&badConn{})
176+
err = client.Pool().Put(cn)
177+
Expect(err).NotTo(HaveOccurred())
178+
179+
{
180+
tx, err := client.Watch("key")
181+
Expect(err).To(MatchError("bad connection"))
182+
Expect(tx).To(BeNil())
183+
}
184+
185+
{
186+
tx, err := client.Watch("key")
187+
Expect(err).NotTo(HaveOccurred())
188+
189+
err = tx.Ping().Err()
190+
Expect(err).NotTo(HaveOccurred())
191+
192+
err = tx.Close()
193+
Expect(err).NotTo(HaveOccurred())
194+
}
195+
})
169196
})

pool.go

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -246,13 +246,14 @@ func (p *connPool) Get() (cn *conn, isNew bool, err error) {
246246

247247
// Try to create a new one.
248248
if p.conns.Reserve() {
249+
isNew = true
250+
249251
cn, err = p.new()
250252
if err != nil {
251253
p.conns.Remove(nil)
252254
return
253255
}
254256
p.conns.Add(cn)
255-
isNew = true
256257
return
257258
}
258259

@@ -481,13 +482,13 @@ func (p *stickyConnPool) Put(cn *conn) error {
481482
return nil
482483
}
483484

484-
func (p *stickyConnPool) remove(reason error) (err error) {
485-
err = p.pool.Remove(p.cn, reason)
485+
func (p *stickyConnPool) remove(reason error) error {
486+
err := p.pool.Remove(p.cn, reason)
486487
p.cn = nil
487488
return err
488489
}
489490

490-
func (p *stickyConnPool) Remove(cn *conn, _ error) error {
491+
func (p *stickyConnPool) Remove(cn *conn, reason error) error {
491492
defer p.mx.Unlock()
492493
p.mx.Lock()
493494
if p.closed {
@@ -499,7 +500,7 @@ func (p *stickyConnPool) Remove(cn *conn, _ error) error {
499500
if cn != nil && p.cn != cn {
500501
panic("p.cn != cn")
501502
}
502-
return nil
503+
return p.remove(reason)
503504
}
504505

505506
func (p *stickyConnPool) Len() int {
@@ -522,15 +523,6 @@ func (p *stickyConnPool) FreeLen() int {
522523

523524
func (p *stickyConnPool) Stats() *PoolStats { return nil }
524525

525-
func (p *stickyConnPool) Reset(reason error) (err error) {
526-
p.mx.Lock()
527-
if p.cn != nil {
528-
err = p.remove(reason)
529-
}
530-
p.mx.Unlock()
531-
return err
532-
}
533-
534526
func (p *stickyConnPool) Close() error {
535527
defer p.mx.Unlock()
536528
p.mx.Lock()

pubsub.go

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,18 @@ func (c *Client) Publish(channel, message string) *IntCmd {
1717
// http://redis.io/topics/pubsub. It's NOT safe for concurrent use by
1818
// multiple goroutines.
1919
type PubSub struct {
20-
*baseClient
20+
base *baseClient
2121

2222
channels []string
2323
patterns []string
24+
25+
nsub int // number of active subscriptions
2426
}
2527

2628
// Deprecated: use Subscribe/PSubscribe instead.
2729
func (c *Client) PubSub() *PubSub {
2830
return &PubSub{
29-
baseClient: &baseClient{
31+
base: &baseClient{
3032
opt: c.opt,
3133
connPool: newStickyConnPool(c.connPool, false),
3234
},
@@ -46,7 +48,7 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) {
4648
}
4749

4850
func (c *PubSub) subscribe(cmd string, channels ...string) error {
49-
cn, _, err := c.conn()
51+
cn, _, err := c.base.conn()
5052
if err != nil {
5153
return err
5254
}
@@ -65,6 +67,7 @@ func (c *PubSub) Subscribe(channels ...string) error {
6567
err := c.subscribe("SUBSCRIBE", channels...)
6668
if err == nil {
6769
c.channels = append(c.channels, channels...)
70+
c.nsub += len(channels)
6871
}
6972
return err
7073
}
@@ -74,6 +77,7 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
7477
err := c.subscribe("PSUBSCRIBE", patterns...)
7578
if err == nil {
7679
c.patterns = append(c.patterns, patterns...)
80+
c.nsub += len(patterns)
7781
}
7882
return err
7983
}
@@ -113,8 +117,12 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error {
113117
return err
114118
}
115119

120+
func (c *PubSub) Close() error {
121+
return c.base.Close()
122+
}
123+
116124
func (c *PubSub) Ping(payload string) error {
117-
cn, _, err := c.conn()
125+
cn, _, err := c.base.conn()
118126
if err != nil {
119127
return err
120128
}
@@ -178,7 +186,7 @@ func (p *Pong) String() string {
178186
return "Pong"
179187
}
180188

181-
func newMessage(reply []interface{}) (interface{}, error) {
189+
func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) {
182190
switch kind := reply[0].(string); kind {
183191
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
184192
return &Subscription{
@@ -210,7 +218,11 @@ func newMessage(reply []interface{}) (interface{}, error) {
210218
// is not received in time. This is low-level API and most clients
211219
// should use ReceiveMessage.
212220
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
213-
cn, _, err := c.conn()
221+
if c.nsub == 0 {
222+
c.resubscribe()
223+
}
224+
225+
cn, _, err := c.base.conn()
214226
if err != nil {
215227
return nil, err
216228
}
@@ -222,7 +234,8 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
222234
if err != nil {
223235
return nil, err
224236
}
225-
return newMessage(cmd.Val())
237+
238+
return c.newMessage(cmd.Val())
226239
}
227240

228241
// Receive returns a message as a Subscription, Message, PMessage,
@@ -232,22 +245,6 @@ func (c *PubSub) Receive() (interface{}, error) {
232245
return c.ReceiveTimeout(0)
233246
}
234247

235-
func (c *PubSub) reconnect(reason error) {
236-
// Close current connection.
237-
c.connPool.(*stickyConnPool).Reset(reason)
238-
239-
if len(c.channels) > 0 {
240-
if err := c.Subscribe(c.channels...); err != nil {
241-
Logger.Printf("Subscribe failed: %s", err)
242-
}
243-
}
244-
if len(c.patterns) > 0 {
245-
if err := c.PSubscribe(c.patterns...); err != nil {
246-
Logger.Printf("PSubscribe failed: %s", err)
247-
}
248-
}
249-
}
250-
251248
// ReceiveMessage returns a message or error. It automatically
252249
// reconnects to Redis in case of network errors.
253250
func (c *PubSub) ReceiveMessage() (*Message, error) {
@@ -259,27 +256,25 @@ func (c *PubSub) ReceiveMessage() (*Message, error) {
259256
return nil, err
260257
}
261258

262-
goodConn := errNum == 0
263259
errNum++
264-
265-
if goodConn {
260+
if errNum < 3 {
266261
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
267262
err := c.Ping("")
268263
if err == nil {
269264
continue
270265
}
271266
Logger.Printf("PubSub.Ping failed: %s", err)
272267
}
273-
}
274-
275-
if errNum > 2 {
268+
} else {
269+
// 3 consecutive errors - connection is bad
270+
// and/or Redis Server is down.
271+
// Sleep to not exceed max number of open connections.
276272
time.Sleep(time.Second)
277273
}
278-
c.reconnect(err)
279274
continue
280275
}
281276

282-
// Reset error number.
277+
// Reset error number, because we received a message.
283278
errNum = 0
284279

285280
switch msg := msgi.(type) {
@@ -300,3 +295,22 @@ func (c *PubSub) ReceiveMessage() (*Message, error) {
300295
}
301296
}
302297
}
298+
299+
func (c *PubSub) putConn(cn *conn, err error) {
300+
if !c.base.putConn(cn, err) {
301+
c.nsub = 0
302+
}
303+
}
304+
305+
func (c *PubSub) resubscribe() {
306+
if len(c.channels) > 0 {
307+
if err := c.Subscribe(c.channels...); err != nil {
308+
Logger.Printf("Subscribe failed: %s", err)
309+
}
310+
}
311+
if len(c.patterns) > 0 {
312+
if err := c.PSubscribe(c.patterns...); err != nil {
313+
Logger.Printf("PSubscribe failed: %s", err)
314+
}
315+
}
316+
}

0 commit comments

Comments
 (0)