Skip to content

Commit 2ec03d9

Browse files
committed
fix: late binding for dial hook
1 parent 180f107 commit 2ec03d9

File tree

5 files changed

+51
-50
lines changed

5 files changed

+51
-50
lines changed

cluster.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -845,9 +845,9 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
845845
c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)
846846
c.cmdable = c.Process
847847

848-
c.hooks.process = c.process
849-
c.hooks.processPipeline = c._processPipeline
850-
c.hooks.processTxPipeline = c._processTxPipeline
848+
c.hooks.setProcess(c.process)
849+
c.hooks.setProcessPipeline(c._processPipeline)
850+
c.hooks.setProcessTxPipeline(c._processTxPipeline)
851851

852852
return c
853853
}

extra/redisotel/tracing.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,7 @@ func (th *tracingHook) DialHook(hook redis.DialHook) redis.DialHook {
8989
return hook(ctx, network, addr)
9090
}
9191

92-
spanOpts := th.spanOpts
93-
spanOpts = append(spanOpts, trace.WithAttributes(
94-
attribute.String("network", network),
95-
attribute.String("addr", addr),
96-
))
97-
98-
ctx, span := th.conf.tracer.Start(ctx, "redis.dial", spanOpts...)
92+
ctx, span := th.conf.tracer.Start(ctx, "redis.dial", th.spanOpts...)
9993
defer span.End()
10094

10195
conn, err := hook(ctx, network, addr)

redis.go

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -37,29 +37,19 @@ type (
3737
)
3838

3939
type hooks struct {
40-
slice []Hook
41-
dial DialHook
42-
process ProcessHook
43-
processPipeline ProcessPipelineHook
44-
processTxPipeline ProcessPipelineHook
40+
slice []Hook
41+
dialHook DialHook
42+
processHook ProcessHook
43+
processPipelineHook ProcessPipelineHook
44+
processTxPipelineHook ProcessPipelineHook
4545
}
4646

4747
func (hs *hooks) AddHook(hook Hook) {
48-
if hs.process == nil {
49-
panic("hs.process == nil")
50-
}
51-
if hs.processPipeline == nil {
52-
panic("hs.processPipeline == nil")
53-
}
54-
if hs.processTxPipeline == nil {
55-
panic("hs.processTxPipeline == nil")
56-
}
57-
5848
hs.slice = append(hs.slice, hook)
59-
hs.dial = hook.DialHook(hs.dial)
60-
hs.process = hook.ProcessHook(hs.process)
61-
hs.processPipeline = hook.ProcessPipelineHook(hs.processPipeline)
62-
hs.processTxPipeline = hook.ProcessPipelineHook(hs.processTxPipeline)
49+
hs.dialHook = hook.DialHook(hs.dialHook)
50+
hs.processHook = hook.ProcessHook(hs.processHook)
51+
hs.processPipelineHook = hook.ProcessPipelineHook(hs.processPipelineHook)
52+
hs.processTxPipelineHook = hook.ProcessPipelineHook(hs.processTxPipelineHook)
6353
}
6454

6555
func (hs *hooks) clone() hooks {
@@ -70,37 +60,37 @@ func (hs *hooks) clone() hooks {
7060
}
7161

7262
func (hs *hooks) setDial(dial DialHook) {
73-
hs.dial = dial
63+
hs.dialHook = dial
7464
for _, h := range hs.slice {
75-
if wrapped := h.DialHook(hs.dial); wrapped != nil {
76-
hs.dial = wrapped
65+
if wrapped := h.DialHook(hs.dialHook); wrapped != nil {
66+
hs.dialHook = wrapped
7767
}
7868
}
7969
}
8070

8171
func (hs *hooks) setProcess(process ProcessHook) {
82-
hs.process = process
72+
hs.processHook = process
8373
for _, h := range hs.slice {
84-
if wrapped := h.ProcessHook(hs.process); wrapped != nil {
85-
hs.process = wrapped
74+
if wrapped := h.ProcessHook(hs.processHook); wrapped != nil {
75+
hs.processHook = wrapped
8676
}
8777
}
8878
}
8979

9080
func (hs *hooks) setProcessPipeline(processPipeline ProcessPipelineHook) {
91-
hs.processPipeline = processPipeline
81+
hs.processPipelineHook = processPipeline
9282
for _, h := range hs.slice {
93-
if wrapped := h.ProcessPipelineHook(hs.processPipeline); wrapped != nil {
94-
hs.processPipeline = wrapped
83+
if wrapped := h.ProcessPipelineHook(hs.processPipelineHook); wrapped != nil {
84+
hs.processPipelineHook = wrapped
9585
}
9686
}
9787
}
9888

9989
func (hs *hooks) setProcessTxPipeline(processTxPipeline ProcessPipelineHook) {
100-
hs.processTxPipeline = processTxPipeline
90+
hs.processTxPipelineHook = processTxPipeline
10191
for _, h := range hs.slice {
102-
if wrapped := h.ProcessPipelineHook(hs.processTxPipeline); wrapped != nil {
103-
hs.processTxPipeline = wrapped
92+
if wrapped := h.ProcessPipelineHook(hs.processTxPipelineHook); wrapped != nil {
93+
hs.processTxPipelineHook = wrapped
10494
}
10595
}
10696
}
@@ -125,6 +115,22 @@ func (hs *hooks) withProcessPipelineHook(
125115
return hook(ctx, cmds)
126116
}
127117

118+
func (hs *hooks) dial(ctx context.Context, network, addr string) (net.Conn, error) {
119+
return hs.dialHook(ctx, network, addr)
120+
}
121+
122+
func (hs *hooks) process(ctx context.Context, cmd Cmder) error {
123+
return hs.processHook(ctx, cmd)
124+
}
125+
126+
func (hs *hooks) processPipeline(ctx context.Context, cmds []Cmder) error {
127+
return hs.processPipelineHook(ctx, cmds)
128+
}
129+
130+
func (hs *hooks) processTxPipeline(ctx context.Context, cmds []Cmder) error {
131+
return hs.processTxPipelineHook(ctx, cmds)
132+
}
133+
128134
//------------------------------------------------------------------------------
129135

130136
type baseClient struct {
@@ -538,8 +544,8 @@ func NewClient(opt *Options) *Client {
538544
opt: opt,
539545
},
540546
}
541-
c.connPool = newConnPool(opt, c.baseClient.dial)
542547
c.init()
548+
c.connPool = newConnPool(opt, c.hooks.dial)
543549

544550
return &c
545551
}

ring.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -495,13 +495,13 @@ func NewRing(opt *RingOptions) *Ring {
495495
ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
496496
ring.cmdable = ring.Process
497497

498-
ring.hooks.process = ring.process
499-
ring.hooks.processPipeline = func(ctx context.Context, cmds []Cmder) error {
498+
ring.hooks.setProcess(ring.process)
499+
ring.hooks.setProcessPipeline(func(ctx context.Context, cmds []Cmder) error {
500500
return ring.generalProcessPipeline(ctx, cmds, false)
501-
}
502-
ring.hooks.processTxPipeline = func(ctx context.Context, cmds []Cmder) error {
501+
})
502+
ring.hooks.setProcessTxPipeline(func(ctx context.Context, cmds []Cmder) error {
503503
return ring.generalProcessPipeline(ctx, cmds, true)
504-
}
504+
})
505505

506506
go ring.sharding.Heartbeat(hbCtx, opt.HeartbeatFrequency)
507507

sentinel.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,11 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
205205
opt: opt,
206206
},
207207
}
208-
connPool = newConnPool(opt, rdb.baseClient.dial)
208+
rdb.init()
209+
210+
connPool = newConnPool(opt, rdb.hooks.dial)
209211
rdb.connPool = connPool
210212
rdb.onClose = failover.Close
211-
rdb.init()
212213

213214
failover.mu.Lock()
214215
failover.onFailover = func(ctx context.Context, addr string) {
@@ -269,10 +270,10 @@ func NewSentinelClient(opt *Options) *SentinelClient {
269270
opt: opt,
270271
},
271272
}
272-
c.connPool = newConnPool(opt, c.baseClient.dial)
273273

274274
c.hooks.setDial(c.baseClient.dial)
275275
c.hooks.setProcess(c.baseClient.process)
276+
c.connPool = newConnPool(opt, c.hooks.dial)
276277

277278
return c
278279
}

0 commit comments

Comments
 (0)