Skip to content

Commit

Permalink
Add websocket keepalive support
Browse files Browse the repository at this point in the history
  • Loading branch information
sinamt committed Feb 5, 2019
1 parent 055fb4b commit 693753f
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 17 deletions.
28 changes: 19 additions & 9 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"strings"
"time"

"github.com/99designs/gqlgen/complexity"
"github.com/99designs/gqlgen/graphql"
Expand All @@ -25,15 +26,16 @@ type params struct {
}

type Config struct {
cacheSize int
upgrader websocket.Upgrader
recover graphql.RecoverFunc
errorPresenter graphql.ErrorPresenterFunc
resolverHook graphql.FieldMiddleware
requestHook graphql.RequestMiddleware
tracer graphql.Tracer
complexityLimit int
disableIntrospection bool
cacheSize int
upgrader websocket.Upgrader
recover graphql.RecoverFunc
errorPresenter graphql.ErrorPresenterFunc
resolverHook graphql.FieldMiddleware
requestHook graphql.RequestMiddleware
tracer graphql.Tracer
complexityLimit int
disableIntrospection bool
connectionKeepAliveTimeout time.Duration
}

func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
Expand Down Expand Up @@ -241,6 +243,14 @@ func CacheSize(size int) Option {

const DefaultCacheSize = 1000

// WebsocketKeepAliveDuration allows you to reconfigure the keepAlive behavior.
// By default, keep-alive is disabled.
func WebsocketKeepAliveDuration(duration time.Duration) Option {
return func(cfg *Config) {
cfg.connectionKeepAliveTimeout = duration
}
}

func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
cfg := &Config{
cacheSize: DefaultCacheSize,
Expand Down
52 changes: 44 additions & 8 deletions handler/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log"
"net/http"
"sync"
"time"

"github.com/99designs/gqlgen/graphql"
"github.com/gorilla/websocket"
Expand All @@ -28,7 +29,7 @@ const (
dataMsg = "data" // Server -> Client
errorMsg = "error" // Server -> Client
completeMsg = "complete" // Server -> Client
//connectionKeepAliveMsg = "ka" // Server -> Client TODO: keepalives
connectionKeepAliveMsg = "ka" // Server -> Client
)

type operationMessage struct {
Expand All @@ -38,13 +39,14 @@ type operationMessage struct {
}

type wsConnection struct {
ctx context.Context
conn *websocket.Conn
exec graphql.ExecutableSchema
active map[string]context.CancelFunc
mu sync.Mutex
cfg *Config
cache *lru.Cache
ctx context.Context
conn *websocket.Conn
exec graphql.ExecutableSchema
active map[string]context.CancelFunc
mu sync.Mutex
cfg *Config
cache *lru.Cache
keepAliveTimer *time.Timer

initPayload InitPayload
}
Expand Down Expand Up @@ -108,10 +110,28 @@ func (c *wsConnection) init() bool {
func (c *wsConnection) write(msg *operationMessage) {
c.mu.Lock()
c.conn.WriteJSON(msg)
if c.cfg.connectionKeepAliveTimeout != 0 && c.keepAliveTimer != nil {
c.keepAliveTimer.Reset(c.cfg.connectionKeepAliveTimeout)
}
c.mu.Unlock()
}

func (c *wsConnection) run() {
// We create a cancellation that will shutdown the keep-alive when we leave
// this function.
ctx, cancel := context.WithCancel(c.ctx)
defer cancel()

// Create a timer that will fire every interval if a write hasn't been made
// to keep the connection alive.
if c.cfg.connectionKeepAliveTimeout != 0 {
c.mu.Lock()
c.keepAliveTimer = time.NewTimer(c.cfg.connectionKeepAliveTimeout)
c.mu.Unlock()

go c.keepAlive(ctx)
}

for {
message := c.readOp()
if message == nil {
Expand Down Expand Up @@ -144,6 +164,22 @@ func (c *wsConnection) run() {
}
}

func (c *wsConnection) keepAlive(ctx context.Context) {
for {
select {
case <-ctx.Done():
if !c.keepAliveTimer.Stop() {
<-c.keepAliveTimer.C
}
return
case <-c.keepAliveTimer.C:
// We don't reset the timer here, because the `c.write` command
// will reset the timer anyways.
c.write(&operationMessage{Type: connectionKeepAliveMsg})
}
}
}

func (c *wsConnection) subscribe(message *operationMessage) bool {
var reqParams params
if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
Expand Down
36 changes: 36 additions & 0 deletions handler/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -122,6 +123,41 @@ func TestWebsocket(t *testing.T) {
})
}

func TestWebsocketWithKeepAlive(t *testing.T) {
next := make(chan struct{})
h := GraphQL(&executableSchemaStub{next}, WebsocketKeepAliveDuration(10*time.Millisecond))

srv := httptest.NewServer(h)
defer srv.Close()

t.Run("client must receive keepalive", func(t *testing.T) {
c := wsConnect(srv.URL)
defer c.Close()

require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
require.Equal(t, connectionAckMsg, readOp(c).Type)

require.NoError(t, c.WriteJSON(&operationMessage{
Type: startMsg,
ID: "test_1",
Payload: json.RawMessage(`{"query": "subscription { user { title } }"}`),
}))

// keepalive
msg := readOp(c)
require.Equal(t, connectionKeepAliveMsg, msg.Type)

// server message
next <- struct{}{}
msg = readOp(c)
require.Equal(t, dataMsg, msg.Type)

// keepalive
msg = readOp(c)
require.Equal(t, connectionKeepAliveMsg, msg.Type)
})
}

func wsConnect(url string) *websocket.Conn {
c, _, err := websocket.DefaultDialer.Dial(strings.Replace(url, "http://", "ws://", -1), nil)
if err != nil {
Expand Down

0 comments on commit 693753f

Please sign in to comment.