Skip to content

Commit 02ccf05

Browse files
committed
Close the conn on context timeout
1 parent 43d9b98 commit 02ccf05

File tree

4 files changed

+50
-21
lines changed

4 files changed

+50
-21
lines changed

error.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,11 @@ func isRedisError(err error) bool {
6565
}
6666

6767
func isBadConn(err error, allowTimeout bool) bool {
68-
if err == nil {
68+
switch err {
69+
case nil:
6970
return false
71+
case context.Canceled, context.DeadlineExceeded:
72+
return true
7073
}
7174

7275
if isRedisError(err) {

main_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ func redisRingOptions() *redis.RingOptions {
175175
func performAsync(n int, cbs ...func(int)) *sync.WaitGroup {
176176
var wg sync.WaitGroup
177177
for _, cb := range cbs {
178+
wg.Add(n)
178179
for i := 0; i < n; i++ {
179-
wg.Add(1)
180180
go func(cb func(int), i int) {
181181
defer GinkgoRecover()
182182
defer wg.Done()

race_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package redis_test
22

33
import (
44
"bytes"
5+
"context"
56
"fmt"
67
"net"
78
"strconv"
@@ -295,6 +296,25 @@ var _ = Describe("races", func() {
295296
Expect(err).NotTo(HaveOccurred())
296297
})
297298
})
299+
300+
It("should abort on context timeout", func() {
301+
opt := redisClusterOptions()
302+
client := cluster.newClusterClient(ctx, opt)
303+
304+
ctx, cancel := context.WithCancel(context.Background())
305+
306+
wg := performAsync(C, func(_ int) {
307+
_, err := client.XRead(ctx, &redis.XReadArgs{
308+
Streams: []string{"test", "$"},
309+
Block: 1 * time.Second,
310+
}).Result()
311+
Expect(err).To(Equal(context.Canceled))
312+
})
313+
314+
time.Sleep(10 * time.Millisecond)
315+
cancel()
316+
wg.Wait()
317+
})
298318
})
299319

300320
var _ = Describe("cluster races", func() {

redis.go

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"sync/atomic"
78
"time"
89

910
"github.com/go-redis/redis/v8/internal"
@@ -130,20 +131,7 @@ func (hs hooks) processTxPipeline(
130131
}
131132

132133
func (hs hooks) withContext(ctx context.Context, fn func() error) error {
133-
done := ctx.Done()
134-
if done == nil {
135-
return fn()
136-
}
137-
138-
errc := make(chan error, 1)
139-
go func() { errc <- fn() }()
140-
141-
select {
142-
case <-done:
143-
return ctx.Err()
144-
case err := <-errc:
145-
return err
146-
}
134+
return fn()
147135
}
148136

149137
//------------------------------------------------------------------------------
@@ -316,8 +304,24 @@ func (c *baseClient) withConn(
316304
c.releaseConn(ctx, cn, err)
317305
}()
318306

319-
err = fn(ctx, cn)
320-
return err
307+
done := ctx.Done()
308+
if done == nil {
309+
err = fn(ctx, cn)
310+
return err
311+
}
312+
313+
errc := make(chan error, 1)
314+
go func() { errc <- fn(ctx, cn) }()
315+
316+
select {
317+
case <-done:
318+
_ = cn.Close()
319+
320+
err = ctx.Err()
321+
return err
322+
case err = <-errc:
323+
return err
324+
}
321325
})
322326
}
323327

@@ -334,7 +338,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
334338
}
335339
}
336340

337-
retryTimeout := true
341+
retryTimeout := uint32(1)
338342
err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
339343
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
340344
return writeCmd(wr, cmd)
@@ -345,7 +349,9 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
345349

346350
err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
347351
if err != nil {
348-
retryTimeout = cmd.readTimeout() == nil
352+
if cmd.readTimeout() == nil {
353+
atomic.StoreUint32(&retryTimeout, 1)
354+
}
349355
return err
350356
}
351357

@@ -354,7 +360,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
354360
if err == nil {
355361
return nil
356362
}
357-
retry = shouldRetry(err, retryTimeout)
363+
retry = shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
358364
return err
359365
})
360366
if err == nil || !retry {

0 commit comments

Comments
 (0)