Skip to content

Commit

Permalink
add ws init func (TT-9257) (#366)
Browse files Browse the repository at this point in the history
- sync wg commit:
wundergraph#406

---------

Co-authored-by: chedom <domanchukits@gmail.com>
Co-authored-by: Jens Neuse <jens.neuse@gmx.de>
Co-authored-by: Sergey Petrunin <spetrunin@users.noreply.github.com>
  • Loading branch information
4 people committed Jul 3, 2023
1 parent 7adc3c0 commit bfd88c5
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 27 deletions.
15 changes: 13 additions & 2 deletions pkg/http/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,14 @@ func (w *WebsocketSubscriptionClient) isClosedConnectionError(err error) bool {
return w.isClosedConnection
}

func HandleWebsocket(done chan bool, errChan chan error, conn net.Conn, executorPool subscription.ExecutorPool, logger abstractlogger.Logger) {
func HandleWebsocketWithInitFunc(
done chan bool,
errChan chan error,
conn net.Conn,
executorPool subscription.ExecutorPool,
logger abstractlogger.Logger,
initFunc subscription.WebsocketInitFunc,
) {
defer func() {
if err := conn.Close(); err != nil {
logger.Error("http.HandleWebsocket()",
Expand All @@ -128,7 +135,7 @@ func HandleWebsocket(done chan bool, errChan chan error, conn net.Conn, executor
}()

websocketClient := NewWebsocketSubscriptionClient(logger, conn)
subscriptionHandler, err := subscription.NewHandler(logger, websocketClient, executorPool)
subscriptionHandler, err := subscription.NewHandlerWithInitFunc(logger, websocketClient, executorPool, initFunc)
if err != nil {
logger.Error("http.HandleWebsocket()",
abstractlogger.String("message", "could not create subscriptionHandler"),
Expand All @@ -143,6 +150,10 @@ func HandleWebsocket(done chan bool, errChan chan error, conn net.Conn, executor
subscriptionHandler.Handle(context.Background()) // Blocking
}

func HandleWebsocket(done chan bool, errChan chan error, conn net.Conn, executorPool subscription.ExecutorPool, logger abstractlogger.Logger) {
HandleWebsocketWithInitFunc(done, errChan, conn, executorPool, logger, nil)
}

// handleWebsocket will handle the websocket connection.
func (g *GraphQLHTTPRequestHandler) handleWebsocket(conn net.Conn) {
done := make(chan bool)
Expand Down
4 changes: 2 additions & 2 deletions pkg/subscription/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ func NewInitialHttpRequestContext(r *http.Request) *InitialHttpRequestContext {

type subscriptionCancellations map[string]context.CancelFunc

func (sc subscriptionCancellations) Add(id string) (context.Context, error) {
func (sc subscriptionCancellations) AddWithParent(id string, parent context.Context) (context.Context, error) {
_, ok := sc[id]
if ok {
return nil, fmt.Errorf("subscriber for %s already exists", id)
}

ctx, cancelFunc := context.WithCancel(context.Background())
ctx, cancelFunc := context.WithCancel(parent)
sc[id] = cancelFunc
return ctx, nil
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/subscription/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestSubscriptionCancellations(t *testing.T) {
t.Run("should add a cancellation func to map", func(t *testing.T) {
require.Equal(t, 0, len(cancellations))

ctx, err = cancellations.Add("1")
ctx, err = cancellations.AddWithParent("1", context.Background())
assert.Nil(t, err)
assert.Equal(t, 1, len(cancellations))
assert.NotNil(t, ctx)
Expand All @@ -56,17 +56,17 @@ func TestSubscriptionIdsShouldBeUnique(t *testing.T) {
var ctx context.Context
var err error

ctx, err = cancellations.Add("1")
ctx, err = cancellations.AddWithParent("1", context.Background())
assert.Nil(t, err)
assert.Equal(t, 1, len(cancellations))
assert.NotNil(t, ctx)

ctx, err = cancellations.Add("2")
ctx, err = cancellations.AddWithParent("2", context.Background())
assert.Nil(t, err)
assert.Equal(t, 2, len(cancellations))
assert.NotNil(t, ctx)

ctx, err = cancellations.Add("2")
ctx, err = cancellations.AddWithParent("2", context.Background())
assert.NotNil(t, err)
assert.Equal(t, 2, len(cancellations))
assert.Nil(t, ctx)
Expand Down
92 changes: 78 additions & 14 deletions pkg/subscription/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ type Executor interface {
Reset()
}

// WebsocketInitFunc is called when the server receives connection init message from the client.
// This can be used to check initial payload to see whether to accept the websocket connection.
type WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)

// Handler is the actual subscription handler which will keep track on how to handle messages coming from the client.
type Handler struct {
logger abstractlogger.Logger
Expand All @@ -78,10 +82,16 @@ type Handler struct {
executorPool ExecutorPool
// bufferPool will hold buffers.
bufferPool *sync.Pool
// initFunc will check initial payload to see whether to accept the websocket connection.
initFunc WebsocketInitFunc
}

// NewHandler creates a new subscription handler.
func NewHandler(logger abstractlogger.Logger, client Client, executorPool ExecutorPool) (*Handler, error) {
func NewHandlerWithInitFunc(
logger abstractlogger.Logger,
client Client,
executorPool ExecutorPool,
initFunc WebsocketInitFunc,
) (*Handler, error) {
keepAliveInterval, err := time.ParseDuration(DefaultKeepAliveInterval)
if err != nil {
return nil, err
Expand All @@ -105,9 +115,15 @@ func NewHandler(logger abstractlogger.Logger, client Client, executorPool Execut
return &writer
},
},
initFunc: initFunc,
}, nil
}

// NewHandler creates a new subscription handler.
func NewHandler(logger abstractlogger.Logger, client Client, executorPool ExecutorPool) (*Handler, error) {
return NewHandlerWithInitFunc(logger, client, executorPool, nil)
}

// Handle will handle the subscription connection.
func (h *Handler) Handle(ctx context.Context) {
defer func() {
Expand All @@ -134,10 +150,15 @@ func (h *Handler) Handle(ctx context.Context) {
} else if message != nil {
switch message.Type {
case MessageTypeConnectionInit:
h.handleInit()
ctx, err = h.handleInit(ctx, message.Payload)
if err != nil {
h.terminateConnection("failed to accept the websocket connection")
return
}

go h.handleKeepAlive(ctx)
case MessageTypeStart:
h.handleStart(message.Id, message.Payload)
h.handleStart(ctx, message.Id, message.Payload)
case MessageTypeStop:
h.handleStop(message.Id)
case MessageTypeConnectionTerminate:
Expand Down Expand Up @@ -166,21 +187,34 @@ func (h *Handler) ChangeSubscriptionUpdateInterval(d time.Duration) {
}

// handleInit will handle an init message.
func (h *Handler) handleInit() {
func (h *Handler) handleInit(ctx context.Context, payload []byte) (extendedCtx context.Context, err error) {
if h.initFunc != nil {
var initPayload InitPayload
// decode initial payload
if len(payload) > 0 {
initPayload = payload
}
// check initial payload to see whether to accept the websocket connection
if extendedCtx, err = h.initFunc(ctx, initPayload); err != nil {
return extendedCtx, err
}
} else {
extendedCtx = ctx
}

ackMessage := Message{
Type: MessageTypeConnectionAck,
}

err := h.client.WriteToClient(ackMessage)
if err != nil {
h.logger.Error("subscription.Handler.handleInit()",
abstractlogger.Error(err),
)
if err = h.client.WriteToClient(ackMessage); err != nil {
return extendedCtx, err
}

return extendedCtx, nil
}

// handleStart will handle s start message.
func (h *Handler) handleStart(id string, payload []byte) {
func (h *Handler) handleStart(ctx context.Context, id string, payload []byte) {
executor, err := h.executorPool.Get(payload)
if err != nil {
h.logger.Error("subscription.Handler.handleStart()",
Expand All @@ -197,7 +231,7 @@ func (h *Handler) handleStart(id string, payload []byte) {
}

if executor.OperationType() == ast.OperationTypeSubscription {
ctx, subsErr := h.subCancellations.Add(id)
ctx, subsErr := h.subCancellations.AddWithParent(id, ctx)
if subsErr != nil {
h.handleError(id, graphql.RequestErrorsFromError(subsErr))
return
Expand All @@ -206,7 +240,7 @@ func (h *Handler) handleStart(id string, payload []byte) {
return
}

go h.handleNonSubscriptionOperation(id, executor)
go h.handleNonSubscriptionOperation(ctx, id, executor)
}

func (h *Handler) handleOnBeforeStart(executor Executor) error {
Expand All @@ -223,7 +257,7 @@ func (h *Handler) handleOnBeforeStart(executor Executor) error {
}

// handleNonSubscriptionOperation will handle a non-subscription operation like a query or a mutation.
func (h *Handler) handleNonSubscriptionOperation(id string, executor Executor) {
func (h *Handler) handleNonSubscriptionOperation(ctx context.Context, id string, executor Executor) {
defer func() {
err := h.executorPool.Put(executor)
if err != nil {
Expand All @@ -233,6 +267,7 @@ func (h *Handler) handleNonSubscriptionOperation(id string, executor Executor) {
}
}()

executor.SetContext(ctx)
buf := h.bufferPool.Get().(*graphql.EngineResultWriter)
buf.Reset()

Expand Down Expand Up @@ -392,6 +427,35 @@ func (h *Handler) sendKeepAlive() {
}
}

func (h *Handler) terminateConnection(reason interface{}) {
payloadBytes, err := json.Marshal(reason)
if err != nil {
h.logger.Error("subscription.Handler.terminateConnection()",
abstractlogger.Error(err),
abstractlogger.Any("errorPayload", reason),
)
}

connectionErrorMessage := Message{
Type: MessageTypeConnectionTerminate,
Payload: payloadBytes,
}

err = h.client.WriteToClient(connectionErrorMessage)
if err != nil {
h.logger.Error("subscription.Handler.terminateConnection()",
abstractlogger.Error(err),
)

err := h.client.Disconnect()
if err != nil {
h.logger.Error("subscription.Handler.terminateConnection()",
abstractlogger.Error(err),
)
}
}
}

// handleConnectionError will handle a connection error message.
func (h *Handler) handleConnectionError(errorPayload interface{}) {
payloadBytes, err := json.Marshal(errorPayload)
Expand Down
70 changes: 65 additions & 5 deletions pkg/subscription/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -296,8 +297,20 @@ func TestHandler_Handle(t *testing.T) {
defer chatServer.Close()

t.Run("connection_init", func(t *testing.T) {
var initPayloadAuthorization string

executorPool, _ := setupEngineV2(t, ctx, chatServer.URL)
_, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool)
_, client, handlerRoutine := setupSubscriptionHandlerWithInitFuncTest(t, executorPool, func(ctx context.Context, initPayload InitPayload) (context.Context, error) {
if initPayloadAuthorization == "" {
return ctx, nil
}

if initPayloadAuthorization != initPayload.Authorization() {
return nil, fmt.Errorf("unknown user: %s", initPayload.Authorization())
}

return ctx, nil
})

t.Run("should send connection error message when error on read occurrs", func(t *testing.T) {
client.prepareConnectionInitMessage().withError().and().send()
Expand Down Expand Up @@ -331,6 +344,45 @@ func TestHandler_Handle(t *testing.T) {
messagesFromServer := client.readFromServer()
assert.Contains(t, messagesFromServer, expectedMessage)
})

t.Run("should send connection error message when error on check initial payload occurrs", func(t *testing.T) {
initPayloadAuthorization = "123"
defer func() { initPayloadAuthorization = "" }()

client.reconnect().and().prepareConnectionInitMessageWithPayload([]byte(`{"Authorization": "111"}`)).withoutError().and().send()

ctx, cancelFunc := context.WithCancel(context.Background())

cancelFunc()
require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond)

expectedMessage := Message{
Type: MessageTypeConnectionTerminate,
Payload: jsonizePayload(t, "failed to accept the websocket connection"),
}

messagesFromServer := client.readFromServer()
assert.Contains(t, messagesFromServer, expectedMessage)
})

t.Run("should successfully init connection and respond with ack when initial payload successfully occurred ", func(t *testing.T) {
initPayloadAuthorization = "123"
defer func() { initPayloadAuthorization = "" }()

client.reconnect().and().prepareConnectionInitMessageWithPayload([]byte(`{"Authorization": "123"}`)).withoutError().and().send()

ctx, cancelFunc := context.WithCancel(context.Background())

cancelFunc()
require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond)

expectedMessage := Message{
Type: MessageTypeConnectionAck,
}

messagesFromServer := client.readFromServer()
assert.Contains(t, messagesFromServer, expectedMessage)
})
})

t.Run("connection_keep_alive", func(t *testing.T) {
Expand Down Expand Up @@ -446,7 +498,7 @@ func TestHandler_Handle(t *testing.T) {
client.prepareStartMessage("1", payload).withoutError().and().send()

ctx, cancelFunc := context.WithCancel(context.Background())
cancelFunc()
defer cancelFunc()
handlerRoutineFunc := handlerRoutine(ctx)
go handlerRoutineFunc()

Expand Down Expand Up @@ -532,7 +584,7 @@ func TestHandler_Handle(t *testing.T) {
go handlerRoutineFunc()

time.Sleep(10 * time.Millisecond)
cancelFunc()
defer cancelFunc()

go sendChatMutation(t, chatServer.URL)

Expand Down Expand Up @@ -562,7 +614,7 @@ func TestHandler_Handle(t *testing.T) {
go handlerRoutineFunc()

time.Sleep(10 * time.Millisecond)
cancelFunc()
defer cancelFunc()

go sendChatMutation(t, chatServer.URL)

Expand Down Expand Up @@ -819,10 +871,18 @@ func setupEngineV2(t *testing.T, ctx context.Context, chatServerURL string) (*Ex
}

func setupSubscriptionHandlerTest(t *testing.T, executorPool ExecutorPool) (subscriptionHandler *Handler, client *mockClient, routine handlerRoutine) {
return setupSubscriptionHandlerWithInitFuncTest(t, executorPool, nil)
}

func setupSubscriptionHandlerWithInitFuncTest(
t *testing.T,
executorPool ExecutorPool,
initFunc WebsocketInitFunc,
) (subscriptionHandler *Handler, client *mockClient, routine handlerRoutine) {
client = newMockClient()

var err error
subscriptionHandler, err = NewHandler(abstractlogger.NoopLogger, client, executorPool)
subscriptionHandler, err = NewHandlerWithInitFunc(abstractlogger.NoopLogger, client, executorPool, initFunc)
require.NoError(t, err)

routine = func(ctx context.Context) func() bool {
Expand Down
Loading

0 comments on commit bfd88c5

Please sign in to comment.