Skip to content

Commit

Permalink
Merge pull request #26 from guilt/master
Browse files Browse the repository at this point in the history
Add hooks for onUnsubscribe, before unsubscribe happens.
  • Loading branch information
DrmagicE committed Sep 29, 2020
2 parents 132bfa3 + a19d6da commit 622fdd7
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -95,6 +95,7 @@ Gmqtt implements the following hooks:
* OnSessionTerminated
* OnSubscribe
* OnSubscribed
* OnUnsubscribe
* OnUnsubscribed
* OnMsgArrived
* OnAcked
Expand Down
3 changes: 2 additions & 1 deletion README_ZH.md
Expand Up @@ -94,6 +94,7 @@ Gmqtt实现了下列钩子方法
* OnSessionTerminated
* OnSubscribe
* OnSubscribed
* OnUnsubscribe
* OnUnsubscribed
* OnMsgArrived
* OnAcked
Expand Down Expand Up @@ -129,4 +130,4 @@ $ go test -race .
* 支持MQTT V3和V5
* 桥接模式,集群模式(看情况)

*暂时不保证向后兼容,在添加上述新功能时可能会有breaking changes。*
*暂时不保证向后兼容,在添加上述新功能时可能会有breaking changes。*
3 changes: 3 additions & 0 deletions client.go
Expand Up @@ -667,6 +667,9 @@ func (client *client) unsubscribeHandler(unSub *packets.Unsubscribe) {
unSuback := unSub.NewUnSubBack()
client.write(unSuback)
for _, topicName := range unSub.Topics {
if srv.hooks.OnUnsubscribe != nil {
srv.hooks.OnUnsubscribe(context.Background(), client, topicName)
}
srv.subscriptionsDB.Unsubscribe(client.opts.clientID, topicName)
if srv.hooks.OnUnsubscribed != nil {
srv.hooks.OnUnsubscribed(context.Background(), client, topicName)
Expand Down
6 changes: 6 additions & 0 deletions hook.go
Expand Up @@ -12,6 +12,7 @@ type Hooks struct {
OnStop
OnSubscribe
OnSubscribed
OnUnsubscribe
OnUnsubscribed
OnMsgArrived
OnConnect
Expand Down Expand Up @@ -55,6 +56,11 @@ type OnSubscribed func(ctx context.Context, client Client, topic packets.Topic)

type OnSubscribedWrapper func(OnSubscribed) OnSubscribed

// OnUnsubscribe will be called when the topic is being unsubscribed
type OnUnsubscribe func(ctx context.Context, client Client, topicName string)

type OnUnsubscribeWrapper func(OnUnsubscribe) OnUnsubscribe

// OnUnsubscribed will be called after the topic has been unsubscribed
type OnUnsubscribed func(ctx context.Context, client Client, topicName string)

Expand Down
1 change: 1 addition & 0 deletions plugin.go
Expand Up @@ -9,6 +9,7 @@ type HookWrapper struct {
OnSessionTerminatedWrapper OnSessionTerminatedWrapper
OnSubscribeWrapper OnSubscribeWrapper
OnSubscribedWrapper OnSubscribedWrapper
OnUnsubscribeWrapper OnUnsubscribeWrapper
OnUnsubscribedWrapper OnUnsubscribedWrapper
OnMsgArrivedWrapper OnMsgArrivedWrapper
OnAckedWrapper OnAckedWrapper
Expand Down
24 changes: 24 additions & 0 deletions server.go
Expand Up @@ -689,6 +689,7 @@ func (srv *server) loadPlugins() error {
onSessionTerminatedWrapper []OnSessionTerminatedWrapper
onSubscribeWrappers []OnSubscribeWrapper
onSubscribedWrappers []OnSubscribedWrapper
onUnsubscribeWrappers []OnUnsubscribeWrapper
onUnsubscribedWrappers []OnUnsubscribedWrapper
onMsgArrivedWrappers []OnMsgArrivedWrapper
onDeliverWrappers []OnDeliverWrapper
Expand Down Expand Up @@ -729,6 +730,9 @@ func (srv *server) loadPlugins() error {
if hooks.OnSubscribedWrapper != nil {
onSubscribedWrappers = append(onSubscribedWrappers, hooks.OnSubscribedWrapper)
}
if hooks.OnUnsubscribeWrapper != nil {
onUnsubscribeWrappers = append(onUnsubscribeWrappers, hooks.OnUnsubscribeWrapper)
}
if hooks.OnUnsubscribedWrapper != nil {
onUnsubscribedWrappers = append(onUnsubscribedWrappers, hooks.OnUnsubscribedWrapper)
}
Expand All @@ -751,6 +755,7 @@ func (srv *server) loadPlugins() error {
onStopWrappers = append(onStopWrappers, hooks.OnStopWrapper)
}
}

// onAccept
if onAcceptWrappers != nil {
onAccept := func(ctx context.Context, conn net.Conn) bool {
Expand All @@ -761,6 +766,7 @@ func (srv *server) loadPlugins() error {
}
srv.hooks.OnAccept = onAccept
}

// onConnect
if onConnectWrappers != nil {
onConnect := func(ctx context.Context, client Client) (code uint8) {
Expand All @@ -771,6 +777,7 @@ func (srv *server) loadPlugins() error {
}
srv.hooks.OnConnect = onConnect
}

// onConnected
if onConnectedWrappers != nil {
onConnected := func(ctx context.Context, client Client) {}
Expand All @@ -779,6 +786,7 @@ func (srv *server) loadPlugins() error {
}
srv.hooks.OnConnected = onConnected
}

// onSessionCreated
if onSessionCreatedWrapper != nil {
onSessionCreated := func(ctx context.Context, client Client) {}
Expand Down Expand Up @@ -816,6 +824,7 @@ func (srv *server) loadPlugins() error {
}
srv.hooks.OnSubscribe = onSubscribe
}

// onSubscribed
if onSubscribedWrappers != nil {
onSubscribed := func(ctx context.Context, client Client, topic packets.Topic) {}
Expand All @@ -824,6 +833,16 @@ func (srv *server) loadPlugins() error {
}
srv.hooks.OnSubscribed = onSubscribed
}

//onUnsubscribe
if onUnsubscribeWrappers != nil {
onUnsubscribe := func(ctx context.Context, client Client, topicName string) {}
for i := len(onUnsubscribeWrappers); i > 0; i-- {
onUnsubscribe = onUnsubscribeWrappers[i-1](onUnsubscribe)
}
srv.hooks.OnUnsubscribe = onUnsubscribe
}

//onUnsubscribed
if onUnsubscribedWrappers != nil {
onUnsubscribed := func(ctx context.Context, client Client, topicName string) {}
Expand All @@ -832,6 +851,7 @@ func (srv *server) loadPlugins() error {
}
srv.hooks.OnUnsubscribed = onUnsubscribed
}

// onMsgArrived
if onMsgArrivedWrappers != nil {
onMsgArrived := func(ctx context.Context, client Client, msg packets.Message) (valid bool) {
Expand All @@ -842,6 +862,7 @@ func (srv *server) loadPlugins() error {
}
srv.hooks.OnMsgArrived = onMsgArrived
}

// onDeliver
if onDeliverWrappers != nil {
onDeliver := func(ctx context.Context, client Client, msg packets.Message) {}
Expand All @@ -850,6 +871,7 @@ func (srv *server) loadPlugins() error {
}
srv.hooks.OnDeliver = onDeliver
}

// onAcked
if onAckedWrappers != nil {
onAcked := func(ctx context.Context, client Client, msg packets.Message) {}
Expand All @@ -858,6 +880,7 @@ func (srv *server) loadPlugins() error {
}
srv.hooks.OnAcked = onAcked
}

// onClose hooks
if onCloseWrappers != nil {
onClose := func(ctx context.Context, client Client, err error) {}
Expand All @@ -866,6 +889,7 @@ func (srv *server) loadPlugins() error {
}
srv.hooks.OnClose = onClose
}

// onStop
if onStopWrappers != nil {
onStop := func(ctx context.Context) {}
Expand Down

0 comments on commit 622fdd7

Please sign in to comment.