Skip to content

Commit baa48a4

Browse files
committed
feat(pubsub): support sharded pub/sub
1 parent 084c0c8 commit baa48a4

File tree

8 files changed

+233
-2
lines changed

8 files changed

+233
-2
lines changed

cluster.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,6 +1528,16 @@ func (c *ClusterClient) PSubscribe(ctx context.Context, channels ...string) *Pub
15281528
return pubsub
15291529
}
15301530

1531+
// SSubscribe Subscribes the client to the specified shard channels.
1532+
func (c *ClusterClient) SSubscribe(ctx context.Context, channels ...string) *PubSub {
1533+
pubsub := c.pubSub()
1534+
if len(channels) > 0 {
1535+
_ = pubsub.SSubscribe(ctx, channels...)
1536+
}
1537+
return pubsub
1538+
}
1539+
1540+
15311541
func (c *ClusterClient) retryBackoff(attempt int) time.Duration {
15321542
return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
15331543
}

cluster_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,30 @@ var _ = Describe("ClusterClient", func() {
549549
}, 30*time.Second).ShouldNot(HaveOccurred())
550550
})
551551

552+
It("supports sharded PubSub", func() {
553+
pubsub := client.SSubscribe(ctx, "mychannel")
554+
defer pubsub.Close()
555+
556+
Eventually(func() error {
557+
_, err := client.SPublish(ctx, "mychannel", "hello").Result()
558+
if err != nil {
559+
return err
560+
}
561+
562+
msg, err := pubsub.ReceiveTimeout(ctx, time.Second)
563+
if err != nil {
564+
return err
565+
}
566+
567+
_, ok := msg.(*redis.Message)
568+
if !ok {
569+
return fmt.Errorf("got %T, wanted *redis.Message", msg)
570+
}
571+
572+
return nil
573+
}, 30*time.Second).ShouldNot(HaveOccurred())
574+
})
575+
552576
It("supports PubSub.Ping without channels", func() {
553577
pubsub := client.Subscribe(ctx)
554578
defer pubsub.Close()

commands.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,9 +345,12 @@ type Cmdable interface {
345345
ScriptLoad(ctx context.Context, script string) *StringCmd
346346

347347
Publish(ctx context.Context, channel string, message interface{}) *IntCmd
348+
SPublish(ctx context.Context, channel string, message interface{}) *IntCmd
348349
PubSubChannels(ctx context.Context, pattern string) *StringSliceCmd
349350
PubSubNumSub(ctx context.Context, channels ...string) *StringIntMapCmd
350351
PubSubNumPat(ctx context.Context) *IntCmd
352+
PubSubShardChannels(ctx context.Context, pattern string) *StringSliceCmd
353+
PubSubShardNumSub(ctx context.Context, channels ...string) *StringIntMapCmd
351354

352355
ClusterSlots(ctx context.Context) *ClusterSlotsCmd
353356
ClusterNodes(ctx context.Context) *StringCmd
@@ -3078,6 +3081,12 @@ func (c cmdable) Publish(ctx context.Context, channel string, message interface{
30783081
return cmd
30793082
}
30803083

3084+
func (c cmdable) SPublish(ctx context.Context, channel string, message interface{}) *IntCmd {
3085+
cmd := NewIntCmd(ctx, "spublish", channel, message)
3086+
_ = c(ctx, cmd)
3087+
return cmd
3088+
}
3089+
30813090
func (c cmdable) PubSubChannels(ctx context.Context, pattern string) *StringSliceCmd {
30823091
args := []interface{}{"pubsub", "channels"}
30833092
if pattern != "*" {
@@ -3100,6 +3109,28 @@ func (c cmdable) PubSubNumSub(ctx context.Context, channels ...string) *StringIn
31003109
return cmd
31013110
}
31023111

3112+
func (c cmdable) PubSubShardChannels(ctx context.Context, pattern string) *StringSliceCmd {
3113+
args := []interface{}{"pubsub", "shardchannels"}
3114+
if pattern != "*" {
3115+
args = append(args, pattern)
3116+
}
3117+
cmd := NewStringSliceCmd(ctx, args...)
3118+
_ = c(ctx, cmd)
3119+
return cmd
3120+
}
3121+
3122+
func (c cmdable) PubSubShardNumSub(ctx context.Context, channels ...string) *StringIntMapCmd {
3123+
args := make([]interface{}, 2+len(channels))
3124+
args[0] = "pubsub"
3125+
args[1] = "shardnumsub"
3126+
for i, channel := range channels {
3127+
args[2+i] = channel
3128+
}
3129+
cmd := NewStringIntMapCmd(ctx, args...)
3130+
_ = c(ctx, cmd)
3131+
return cmd
3132+
}
3133+
31033134
func (c cmdable) PubSubNumPat(ctx context.Context) *IntCmd {
31043135
cmd := NewIntCmd(ctx, "pubsub", "numpat")
31053136
_ = c(ctx, cmd)

pubsub.go

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ type PubSub struct {
2828
cn *pool.Conn
2929
channels map[string]struct{}
3030
patterns map[string]struct{}
31+
schannels map[string]struct{}
3132

3233
closed bool
3334
exit chan struct{}
@@ -46,6 +47,7 @@ func (c *PubSub) init() {
4647
func (c *PubSub) String() string {
4748
channels := mapKeys(c.channels)
4849
channels = append(channels, mapKeys(c.patterns)...)
50+
channels = append(channels, mapKeys(c.schannels)...)
4951
return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
5052
}
5153

@@ -101,6 +103,13 @@ func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
101103
}
102104
}
103105

106+
if len(c.schannels) > 0 {
107+
err := c._subscribe(ctx, cn, "ssubscribe", mapKeys(c.schannels))
108+
if err != nil && firstErr == nil {
109+
firstErr = err
110+
}
111+
}
112+
104113
return firstErr
105114
}
106115

@@ -208,6 +217,21 @@ func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
208217
return err
209218
}
210219

220+
// SSubscribe Subscribes the client to the specified shard channels.
221+
func (c *PubSub) SSubscribe(ctx context.Context, channels ...string) error {
222+
c.mu.Lock()
223+
defer c.mu.Unlock()
224+
225+
err := c.subscribe(ctx, "ssubscribe", channels...)
226+
if c.schannels == nil {
227+
c.schannels = make(map[string]struct{})
228+
}
229+
for _, s := range channels {
230+
c.schannels[s] = struct{}{}
231+
}
232+
return err
233+
}
234+
211235
// Unsubscribe the client from the given channels, or from all of
212236
// them if none is given.
213237
func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
@@ -234,6 +258,19 @@ func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
234258
return err
235259
}
236260

261+
// SUnsubscribe unsubscribes the client from the given shard channels,
262+
// or from all of them if none is given.
263+
func (c *PubSub) SUnsubscribe(ctx context.Context, channels ...string) error {
264+
c.mu.Lock()
265+
defer c.mu.Unlock()
266+
267+
for _, channel := range channels {
268+
delete(c.schannels, channel)
269+
}
270+
err := c.subscribe(ctx, "sunsubscribe", channels...)
271+
return err
272+
}
273+
237274
func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
238275
cn, err := c.conn(ctx, channels)
239276
if err != nil {
@@ -311,15 +348,15 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
311348
}, nil
312349
case []interface{}:
313350
switch kind := reply[0].(string); kind {
314-
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
351+
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe", "ssubscribe", "sunsubscribe":
315352
// Can be nil in case of "unsubscribe".
316353
channel, _ := reply[1].(string)
317354
return &Subscription{
318355
Kind: kind,
319356
Channel: channel,
320357
Count: int(reply[2].(int64)),
321358
}, nil
322-
case "message":
359+
case "message", "smessage":
323360
switch payload := reply[2].(type) {
324361
case string:
325362
return &Message{

pubsub_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,35 @@ var _ = Describe("PubSub", func() {
102102
Expect(len(channels)).To(BeNumerically(">=", 2))
103103
})
104104

105+
It("should sharded pub/sub channels", func() {
106+
channels, err := client.PubSubShardChannels(ctx, "mychannel*").Result()
107+
Expect(err).NotTo(HaveOccurred())
108+
Expect(channels).To(BeEmpty())
109+
110+
pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2")
111+
defer pubsub.Close()
112+
113+
channels, err = client.PubSubShardChannels(ctx, "mychannel*").Result()
114+
Expect(err).NotTo(HaveOccurred())
115+
Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"}))
116+
117+
channels, err = client.PubSubShardChannels(ctx, "").Result()
118+
Expect(err).NotTo(HaveOccurred())
119+
Expect(channels).To(BeEmpty())
120+
121+
channels, err = client.PubSubShardChannels(ctx, "*").Result()
122+
Expect(err).NotTo(HaveOccurred())
123+
Expect(len(channels)).To(BeNumerically(">=", 2))
124+
125+
nums, err := client.PubSubShardNumSub(ctx, "mychannel", "mychannel2", "mychannel3").Result()
126+
Expect(err).NotTo(HaveOccurred())
127+
Expect(nums).To(Equal(map[string]int64{
128+
"mychannel": 1,
129+
"mychannel2": 1,
130+
"mychannel3": 0,
131+
}))
132+
})
133+
105134
It("should return the numbers of subscribers", func() {
106135
pubsub := client.Subscribe(ctx, "mychannel", "mychannel2")
107136
defer pubsub.Close()
@@ -204,6 +233,82 @@ var _ = Describe("PubSub", func() {
204233
Expect(stats.Misses).To(Equal(uint32(1)))
205234
})
206235

236+
It("should sharded pub/sub", func() {
237+
pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2")
238+
defer pubsub.Close()
239+
240+
{
241+
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
242+
Expect(err).NotTo(HaveOccurred())
243+
subscr := msgi.(*redis.Subscription)
244+
Expect(subscr.Kind).To(Equal("ssubscribe"))
245+
Expect(subscr.Channel).To(Equal("mychannel"))
246+
Expect(subscr.Count).To(Equal(1))
247+
}
248+
249+
{
250+
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
251+
Expect(err).NotTo(HaveOccurred())
252+
subscr := msgi.(*redis.Subscription)
253+
Expect(subscr.Kind).To(Equal("ssubscribe"))
254+
Expect(subscr.Channel).To(Equal("mychannel2"))
255+
Expect(subscr.Count).To(Equal(2))
256+
}
257+
258+
{
259+
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
260+
Expect(err.(net.Error).Timeout()).To(Equal(true))
261+
Expect(msgi).NotTo(HaveOccurred())
262+
}
263+
264+
n, err := client.SPublish(ctx, "mychannel", "hello").Result()
265+
Expect(err).NotTo(HaveOccurred())
266+
Expect(n).To(Equal(int64(1)))
267+
268+
n, err = client.SPublish(ctx, "mychannel2", "hello2").Result()
269+
Expect(err).NotTo(HaveOccurred())
270+
Expect(n).To(Equal(int64(1)))
271+
272+
Expect(pubsub.SUnsubscribe(ctx, "mychannel", "mychannel2")).NotTo(HaveOccurred())
273+
274+
{
275+
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
276+
Expect(err).NotTo(HaveOccurred())
277+
msg := msgi.(*redis.Message)
278+
Expect(msg.Channel).To(Equal("mychannel"))
279+
Expect(msg.Payload).To(Equal("hello"))
280+
}
281+
282+
{
283+
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
284+
Expect(err).NotTo(HaveOccurred())
285+
msg := msgi.(*redis.Message)
286+
Expect(msg.Channel).To(Equal("mychannel2"))
287+
Expect(msg.Payload).To(Equal("hello2"))
288+
}
289+
290+
{
291+
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
292+
Expect(err).NotTo(HaveOccurred())
293+
subscr := msgi.(*redis.Subscription)
294+
Expect(subscr.Kind).To(Equal("sunsubscribe"))
295+
Expect(subscr.Channel).To(Equal("mychannel"))
296+
Expect(subscr.Count).To(Equal(1))
297+
}
298+
299+
{
300+
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
301+
Expect(err).NotTo(HaveOccurred())
302+
subscr := msgi.(*redis.Subscription)
303+
Expect(subscr.Kind).To(Equal("sunsubscribe"))
304+
Expect(subscr.Channel).To(Equal("mychannel2"))
305+
Expect(subscr.Count).To(Equal(0))
306+
}
307+
308+
stats := client.PoolStats()
309+
Expect(stats.Misses).To(Equal(uint32(1)))
310+
})
311+
207312
It("should ping/pong", func() {
208313
pubsub := client.Subscribe(ctx, "mychannel")
209314
defer pubsub.Close()

redis.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,16 @@ func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
691691
return pubsub
692692
}
693693

694+
// SSubscribe Subscribes the client to the specified shard channels.
695+
// Channels can be omitted to create empty subscription.
696+
func (c *Client) SSubscribe(ctx context.Context, channels ...string) *PubSub {
697+
pubsub := c.pubSub()
698+
if len(channels) > 0 {
699+
_ = pubsub.SSubscribe(ctx, channels...)
700+
}
701+
return pubsub
702+
}
703+
694704
//------------------------------------------------------------------------------
695705

696706
type conn struct {

ring.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,19 @@ func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub {
504504
return shard.Client.PSubscribe(ctx, channels...)
505505
}
506506

507+
// SSubscribe Subscribes the client to the specified shard channels.
508+
func (c *Ring) SSubscribe(ctx context.Context, channels ...string) *PubSub {
509+
if len(channels) == 0 {
510+
panic("at least one channel is required")
511+
}
512+
shard, err := c.shards.GetByKey(channels[0])
513+
if err != nil {
514+
// TODO: return PubSub with sticky error
515+
panic(err)
516+
}
517+
return shard.Client.SSubscribe(ctx, channels...)
518+
}
519+
507520
// ForEachShard concurrently calls the fn on each live shard in the ring.
508521
// It returns the first error if any.
509522
func (c *Ring) ForEachShard(

universal.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ type UniversalClient interface {
190190
Process(ctx context.Context, cmd Cmder) error
191191
Subscribe(ctx context.Context, channels ...string) *PubSub
192192
PSubscribe(ctx context.Context, channels ...string) *PubSub
193+
SSubscribe(ctx context.Context, channels ...string) *PubSub
193194
Close() error
194195
PoolStats() *PoolStats
195196
}

0 commit comments

Comments
 (0)