Skip to content

Commit

Permalink
feat: add a close handler to client
Browse files Browse the repository at this point in the history
  • Loading branch information
brianmcgee committed Sep 10, 2022
1 parent dbc4d23 commit 39bc726
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
26 changes: 20 additions & 6 deletions client.go
Expand Up @@ -22,6 +22,7 @@ var (
type (
ResponseFuture = async.Future[async.Result[*Response]]
RequestHandler = func(req Request)
CloseHandler = func(err error)
)

type Client interface {
Expand All @@ -31,18 +32,21 @@ type Client interface {
SendContext(ctx context.Context, req Request, resp *Response) error
SendAsync(req Request) ResponseFuture

SetCloseHandler(handler CloseHandler)
SetRequestHandler(handler RequestHandler)

Close() error
}

type client struct {
dialer Dialer
conn Connection
inFlight sync.Map
log *log.Entry
closed atomic.Bool
reqHandler RequestHandler
dialer Dialer
conn Connection
inFlight sync.Map
log *log.Entry
closed atomic.Bool
reqHandler RequestHandler
closeError error
closeHandler CloseHandler
}

func NewClient(dialer Dialer) Client {
Expand Down Expand Up @@ -70,13 +74,18 @@ func (c *client) SetRequestHandler(handler RequestHandler) {
c.reqHandler = handler
}

func (c *client) SetCloseHandler(handler CloseHandler) {
c.closeHandler = handler
}

func (c *client) readMessages() {
for !c.closed.Load() {
// read the next response
bytes, err := c.conn.Read()
if err != nil {
// set the client has closed and break out of the read loop
if err == ErrClosed {
c.closeError = err
c.Close()
break
}
Expand Down Expand Up @@ -123,6 +132,11 @@ func (c *client) Close() error {
value.(ResponseFuture).Set(async.NewResultErr[*Response](ErrClosed))
return true
})

if c.closeHandler != nil {
c.closeHandler(c.closeError)
}

return nil
} else {
return ErrClosed
Expand Down
9 changes: 9 additions & 0 deletions client_test.go
Expand Up @@ -2,6 +2,7 @@ package jsonrpc

import (
"encoding/json"
"sync/atomic"
"testing"

"github.com/gorilla/websocket"
Expand All @@ -27,6 +28,13 @@ func TestClient_ServerDisconnect(t *testing.T) {

dialer := WebSocketDialer{Url: srv.url("/ws")}
client := NewClient(dialer)

// capture close errors
closeError := atomic.Value{}
client.SetCloseHandler(func(err error) {
closeError.Store(err)
})

err := client.Connect()
assert.Nil(t, err)

Expand All @@ -36,6 +44,7 @@ func TestClient_ServerDisconnect(t *testing.T) {
var resp Response
err = client.Send(*req, &resp)
assert.Error(t, ErrClosed, err)
assert.Equal(t, ErrClosed, closeError.Load())
}

func TestClient_RequestIdMatching(t *testing.T) {
Expand Down

0 comments on commit 39bc726

Please sign in to comment.