Skip to content

Commit

Permalink
Enable Subscription Resolver to return websocket error message (#2506)
Browse files Browse the repository at this point in the history
* Enanble Subscription Resolver to return websocket error message

* add PR link

* lint

* fmt and regenerate

Signed-off-by: Steve Coffman <steve@khanacademy.org>

Signed-off-by: Steve Coffman <steve@khanacademy.org>
Co-authored-by: Zhixin Wen <zwen@nuro.ai>
Co-authored-by: Steve Coffman <steve@khanacademy.org>
  • Loading branch information
3 people committed Jan 13, 2023
1 parent 2bd7cfe commit 11c3a4d
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 8 deletions.
7 changes: 4 additions & 3 deletions _examples/chat/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package chat

import (
"fmt"
"runtime"
"sync"
"testing"

"github.com/99designs/gqlgen/client"
"github.com/99designs/gqlgen/graphql/handler"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"runtime"
"sync"
"testing"
)

func TestChatSubscriptions(t *testing.T) {
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion graphql/handler/transport/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ func (c *wsConnection) subscribe(start time.Time, msg *message) {
c.mu.Unlock()

go func() {
ctx = withSubscriptionErrorContext(ctx)
defer func() {
if r := recover(); r != nil {
err := rc.Recover(ctx, r)
Expand All @@ -362,7 +363,11 @@ func (c *wsConnection) subscribe(start time.Time, msg *message) {
}
c.sendError(msg.id, gqlerr)
}
c.complete(msg.id)
if errs := getSubscriptionError(ctx); len(errs) != 0 {
c.sendError(msg.id, errs...)
} else {
c.complete(msg.id)
}
c.mu.Lock()
delete(c.active, msg.id)
c.mu.Unlock()
Expand Down
69 changes: 69 additions & 0 deletions graphql/handler/transport/websocket_resolver_error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package transport

import (
"context"

"github.com/vektah/gqlparser/v2/gqlerror"
)

// A private key for context that only this package can access. This is important
// to prevent collisions between different context uses
var wsSubscriptionErrorCtxKey = &wsSubscriptionErrorContextKey{"subscription-error"}

type wsSubscriptionErrorContextKey struct {
name string
}

type subscriptionError struct {
errs []*gqlerror.Error
}

// AddSubscriptionError is used to let websocket return an error message after subscription resolver returns a channel.
// for example:
//
// func (r *subscriptionResolver) Method(ctx context.Context) (<-chan *model.Message, error) {
// ch := make(chan *model.Message)
// go func() {
// defer func() {
// close(ch)
// }
// // some kind of block processing (e.g.: gRPC client streaming)
// stream, err := gRPCClientStreamRequest(ctx)
// if err != nil {
// transport.AddSubscriptionError(ctx, err)
// return // must return and close channel so websocket can send error back
// }
// for {
// m, err := stream.Recv()
// if err == io.EOF {
// return
// }
// if err != nil {
// transport.AddSubscriptionError(ctx, err)
// return // must return and close channel so websocket can send error back
// }
// ch <- m
// }
// }()
//
// return ch, nil
// }
//
// see https://github.com/99designs/gqlgen/pull/2506 for more details
func AddSubscriptionError(ctx context.Context, err *gqlerror.Error) {
subscriptionErrStruct := getSubscriptionErrorStruct(ctx)
subscriptionErrStruct.errs = append(subscriptionErrStruct.errs, err)
}

func withSubscriptionErrorContext(ctx context.Context) context.Context {
return context.WithValue(ctx, wsSubscriptionErrorCtxKey, &subscriptionError{})
}

func getSubscriptionErrorStruct(ctx context.Context) *subscriptionError {
v, _ := ctx.Value(wsSubscriptionErrorCtxKey).(*subscriptionError)
return v
}

func getSubscriptionError(ctx context.Context) []*gqlerror.Error {
return getSubscriptionErrorStruct(ctx).errs
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package generated
import "errors"

// Errors defined for retained code that we want to stick around between generations.
//
var (
ErrResolvingHelloWithErrorsByName = errors.New("error resolving HelloWithErrorsByName")
ErrEmptyKeyResolvingHelloWithErrorsByName = errors.New("error (empty key) resolving HelloWithErrorsByName")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package return_values

import (
"github.com/stretchr/testify/require"
"reflect"
"testing"

"github.com/stretchr/testify/require"
)

//go:generate rm -f resolvers.go
Expand Down

0 comments on commit 11c3a4d

Please sign in to comment.