diff --git a/.golangci.yml b/.golangci.yml index 3c294267..6e65e014 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -49,7 +49,7 @@ linters-settings: - golang.org/x/tools - gopkg.in/yaml.v2 - github.com/alexflint/go-arg - - github.com/gorilla/websocket + - github.com/google/uuid forbidigo: forbid: diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 497fdc66..da5e604b 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -26,7 +26,8 @@ When releasing a new version: - The new `optional: generic` allows using a generic type to represent optionality. See the [documentation](genqlient.yaml) for details. - For schemas with enum values that differ only in casing, it's now possible to disable smart-casing in genqlient.yaml; see the [documentation](genqlient.yaml) for `casing` for details. -- genqlient now supports subscriptions +- genqlient now supports subscriptions; the websocket protocol is by default `graphql-transport-ws` but can be set to another value. + See the [documentation](FAQ.md) for how to `subscribe to an API 'subscription' endpoint`. ### Bug fixes: - The presence of negative pointer directives, i.e., `# @genqlient(pointer: false)` are now respected even in the when `optional: pointer` is set in the configuration file. diff --git a/docs/FAQ.md b/docs/FAQ.md index b7b49226..66dde308 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -108,13 +108,18 @@ Once your webSocket client matches the interfaces, you can get your `graphql.Web a loop for incoming messages and errors: ```go -graphqlClient := graphql.NewClientUsingWebSocket( + graphqlClient := graphql.NewClientUsingWebSocket( "ws://localhost:8080/query", &MyDialer{Dialer: dialer}, headers, ) - respChan, errChan, err := count(context.Background(), graphqlClient) + errChan, err := graphqlClient.StartWebSocket(ctx) + if err != nil { + return + } + + dataChan, subscriptionID, err := count(ctx, graphqlClient) if err != nil { return } @@ -122,7 +127,7 @@ graphqlClient := graphql.NewClientUsingWebSocket( defer graphqlClient.CloseWebSocket() for loop := true; loop; { select { - case msg, more := <-respChan: + case msg, more := <-dataChan: if !more { loop = false break @@ -136,9 +141,16 @@ graphqlClient := graphql.NewClientUsingWebSocket( } case err = <-errChan: return + case <-time.After(time.Minute): + err = wsClient.Unsubscribe(subscriptionID) + loop = false } } ``` +To change the websocket protocol from its default value `graphql-transport-ws`, add the following header before calling `graphql.NewClientUsingWebSocket()`: +```go + headers.Add("Sec-WebSocket-Protocol", "graphql-ws") +``` ### … use an API that requires authentication? diff --git a/generate/generate.go b/generate/generate.go index 07661ca9..be26301a 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -282,7 +282,7 @@ func (g *generator) addOperation(op *ast.OperationDefinition) error { docComment = "// " + strings.ReplaceAll(commentLines, "\n", "\n// ") } if op.Operation == ast.Subscription { - docComment += "\n// To close the connection, use [graphql.WebSocketClient.CloseWebSocket()]" + docComment += "\n// To unsubscribe, use [graphql.WebSocketClient.Unsubscribe]" } // If the filename is a pseudo-filename filename.go:startline, just diff --git a/generate/operation.go.tmpl b/generate/operation.go.tmpl index 1d32c88c..1b70890e 100644 --- a/generate/operation.go.tmpl +++ b/generate/operation.go.tmpl @@ -15,7 +15,7 @@ func {{.Name}}( {{.GraphQLName}} {{.GoType.Reference}}, {{end -}} {{end -}} -) ({{if eq .Type "subscription"}}dataChan_ chan {{.Name}}WsResponse, errChan_ chan error,{{else}}data_ *{{.ResponseName}}, {{if .Config.Extensions -}}ext_ map[string]interface{},{{end}}{{end}} err_ error) { +) ({{if eq .Type "subscription"}}dataChan_ chan {{.Name}}WsResponse, subscriptionID_ string,{{else}}data_ *{{.ResponseName}}, {{if .Config.Extensions -}}ext_ map[string]interface{},{{end}}{{end}} err_ error) { req_ := &graphql.Request{ OpName: "{{.Name}}", Query: {{.Name}}_Operation, @@ -36,14 +36,8 @@ func {{.Name}}( } {{end}} {{if eq .Type "subscription"}} - dataChan_ = make(chan {{.Name}}WsResponse, 1) - respChan_ := make(chan json.RawMessage, 1) - - errChan_, err_ = client_.DialWebSocket({{if ne .Config.ContextType "-" -}}ctx_{{else}}context.Background(){{end}}, req_, respChan_) - if err_ != nil { - return nil, nil, err_ - } - go {{.Name}}ForwardData(dataChan_, respChan_, errChan_) + dataChan_ = make(chan {{.Name}}WsResponse) + subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, {{.Name}}ForwardData) {{else}} data_ = &{{.ResponseName}}{} resp_ := &graphql.Response{Data: data_} @@ -55,7 +49,7 @@ func {{.Name}}( ) {{end}} - return {{if eq .Type "subscription"}}dataChan_, errChan_,{{else}}data_, {{if .Config.Extensions -}}resp_.Extensions,{{end -}}{{end}} err_ + return {{if eq .Type "subscription"}}dataChan_, subscriptionID_,{{else}}data_, {{if .Config.Extensions -}}resp_.Extensions,{{end -}}{{end}} err_ } {{if eq .Type "subscription"}} @@ -65,30 +59,26 @@ type {{.Name}}WsResponse struct { Errors error `json:"errors"` } -func {{.Name}}ForwardData(dataChan_ chan {{.Name}}WsResponse, respChan_ chan json.RawMessage, errChan_ chan error) { - defer close(dataChan_) +func {{.Name}}ForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { var gqlResp graphql.Response var wsResp {{.Name}}WsResponse - for { - jsonRaw, more_ := <-respChan_ - if !more_ { - return - } - err := json.Unmarshal(jsonRaw, &gqlResp) - if err != nil { - errChan_ <- err - return - } - if len(gqlResp.Errors) == 0 { - err = json.Unmarshal(jsonRaw, &wsResp) - if err != nil { - errChan_ <- err - return - } - } else { - wsResp.Errors = gqlResp.Errors - } - dataChan_ <- wsResp + err := json.Unmarshal(jsonRawMsg, &gqlResp) + if err != nil { + return err + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRawMsg, &wsResp) + if err != nil { + return err + } + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_, ok := interfaceChan.(chan {{.Name}}WsResponse) + if !ok { + return errors.New("failed to cast interface into 'chan {{.Name}}WsResponse'") } + dataChan_ <- wsResp + return nil } {{end}} diff --git a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go index e236254e..4f1ebca6 100644 --- a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go +++ b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go @@ -3,8 +3,8 @@ package test import ( - "context" "encoding/json" + "errors" "github.com/Khan/genqlient/graphql" ) @@ -24,25 +24,19 @@ subscription SimpleSubscription { } ` -// To close the connection, use [graphql.WebSocketClient.CloseWebSocket()] +// To unsubscribe, use [graphql.WebSocketClient.Unsubscribe] func SimpleSubscription( client_ graphql.WebSocketClient, -) (dataChan_ chan SimpleSubscriptionWsResponse, errChan_ chan error, err_ error) { +) (dataChan_ chan SimpleSubscriptionWsResponse, subscriptionID_ string, err_ error) { req_ := &graphql.Request{ OpName: "SimpleSubscription", Query: SimpleSubscription_Operation, } - dataChan_ = make(chan SimpleSubscriptionWsResponse, 1) - respChan_ := make(chan json.RawMessage, 1) + dataChan_ = make(chan SimpleSubscriptionWsResponse) + subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, SimpleSubscriptionForwardData) - errChan_, err_ = client_.DialWebSocket(context.Background(), req_, respChan_) - if err_ != nil { - return nil, nil, err_ - } - go SimpleSubscriptionForwardData(dataChan_, respChan_, errChan_) - - return dataChan_, errChan_, err_ + return dataChan_, subscriptionID_, err_ } type SimpleSubscriptionWsResponse struct { @@ -51,30 +45,26 @@ type SimpleSubscriptionWsResponse struct { Errors error `json:"errors"` } -func SimpleSubscriptionForwardData(dataChan_ chan SimpleSubscriptionWsResponse, respChan_ chan json.RawMessage, errChan_ chan error) { - defer close(dataChan_) +func SimpleSubscriptionForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { var gqlResp graphql.Response var wsResp SimpleSubscriptionWsResponse - for { - jsonRaw, more_ := <-respChan_ - if !more_ { - return - } - err := json.Unmarshal(jsonRaw, &gqlResp) + err := json.Unmarshal(jsonRawMsg, &gqlResp) + if err != nil { + return err + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRawMsg, &wsResp) if err != nil { - errChan_ <- err - return - } - if len(gqlResp.Errors) == 0 { - err = json.Unmarshal(jsonRaw, &wsResp) - if err != nil { - errChan_ <- err - return - } - } else { - wsResp.Errors = gqlResp.Errors + return err } - dataChan_ <- wsResp + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_, ok := interfaceChan.(chan SimpleSubscriptionWsResponse) + if !ok { + return errors.New("failed to cast interface into 'chan SimpleSubscriptionWsResponse'") } + dataChan_ <- wsResp + return nil } diff --git a/go.mod b/go.mod index d6967948..6d3f2050 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/99designs/gqlgen v0.17.35 github.com/alexflint/go-arg v1.4.2 github.com/bradleyjkemp/cupaloy/v2 v2.6.0 + github.com/google/uuid v1.3.1 github.com/gorilla/websocket v1.5.0 github.com/stretchr/testify v1.8.2 github.com/vektah/gqlparser/v2 v2.5.8 diff --git a/go.sum b/go.sum index 9388a4f5..885c5957 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= +github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= +github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru/v2 v2.0.3 h1:kmRrRLlInXvng0SmLxmQpQkpbYAvcXm7NPDrgxJa9mE= diff --git a/graphql/client.go b/graphql/client.go index b691ac14..ee3fd743 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -10,7 +10,9 @@ import ( "net/http" "net/url" "strings" + "sync" + "github.com/google/uuid" "github.com/vektah/gqlparser/v2/gqlerror" ) @@ -41,32 +43,38 @@ type Client interface { } type WebSocketClient interface { - // DialWebSocket must open a webSocket connection and subscribe to an endpoint + // StartWebSocket must open a webSocket connection and subscribe to an endpoint // of the client's GraphQL API. // - // ctx is the context that should be used to make this request. If context - // is disabled in the genqlient settings, this will be set to - // context.Background(). + // errChan is a channel on which are sent the errors of webSocket + // communication. + // + // err is any error that occurs when setting up the webSocket connection. + StartWebSocket(ctx context.Context) (errChan chan error, err error) + + // CloseWebSocket must close the webSocket connection. + CloseWebSocket() + + // Subscribe must subscribe to an endpoint of the client's GraphQL API. // // req contains the data to be sent to the GraphQL server. Will be marshalled // into JSON bytes. // - // respChan is a channel used to send the data that arrives via the + // interfaceChan is a channel used to send the data that arrives via the // webSocket connection. // - // errChan is a channel on which are sent the errors of webSocket - // communication. + // forwardDataFunc is the function that will cast the received interface into + // the valid type for the subscription's response. // - // err is any error that occurs when setting up the webSocket connection. - DialWebSocket( - ctx context.Context, + // returns a subscriptionID if successful, an error otherwise + Subscribe( req *Request, - respChan chan json.RawMessage, - ) (errChan chan error, err error) + interfaceChan interface{}, + forwardDataFunc ForwardDataFunction, + ) (string, error) - // CloseWebSocket must end the graphql subscription and close the webSocket - // connection. - CloseWebSocket() + // Unsubscribe must unsubscribe from an endpoint of the client's GraphQL API. + Unsubscribe(subscriptionID string) error } type client struct { @@ -76,11 +84,13 @@ type client struct { } type webSocketClient struct { - Dialer Dialer - Header http.Header - conn WSConn - errChan chan error - endpoint string + Dialer Dialer + Header http.Header + conn WSConn + errChan chan error + endpoint string + subscriptions subscriptionMap + sync.RWMutex } // NewClient returns a [Client] which makes requests to the given endpoint, @@ -132,12 +142,15 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Head if headers == nil { headers = http.Header{} } - headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws") + if headers.Get("Sec-WebSocket-Protocol") == "" { + headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws") + } return &webSocketClient{ - Dialer: wsDialer, - Header: headers, - endpoint: endpoint, - errChan: make(chan error, 1), + Dialer: wsDialer, + Header: headers, + errChan: make(chan error), + endpoint: endpoint, + subscriptions: subscriptionMap{map_: make(map[string]subscription)}, } } @@ -249,66 +262,6 @@ func (c *client) MakeRequest(ctx context.Context, req *Request, resp *Response) return nil } -func (w *webSocketClient) DialWebSocket(ctx context.Context, req *Request, respChan chan json.RawMessage) (errChan chan error, err error) { - if req.Query != "" { - if strings.HasPrefix(strings.TrimSpace(req.Query), "query") { - return nil, errors.New("client does not support queries") - } - if strings.HasPrefix(strings.TrimSpace(req.Query), "mutation") { - return nil, errors.New("client does not support mutations") - } - } - - err = w.subscribeAndListen( - ctx, - req, - respChan, - ) - - return w.errChan, err -} - -func (w *webSocketClient) CloseWebSocket() { - defer w.conn.Close() - err := w.sendComplete() - if err != nil { - w.errChan <- err - } - err = w.conn.WriteMessage(closeMessage, formatCloseMessage(closeNormalClosure, "")) - if err != nil { - w.errChan <- err - } -} - -func (w *webSocketClient) subscribeAndListen(ctx context.Context, req *Request, respChan chan json.RawMessage) error { - var err error - w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.Header) - if err != nil { - return err - } - - err = w.sendInit() - if err != nil { - w.conn.Close() - return err - } - err = w.waitForConnAck() - if err != nil { - w.conn.Close() - return err - } - - go w.listenWebSocket(respChan) - - err = w.sendSubscribe(req) - if err != nil { - w.conn.Close() - return err - } - - return nil -} - func (c *client) createPostRequest(req *Request) (*http.Request, error) { if req.Query != "" { if strings.HasPrefix(strings.TrimSpace(req.Query), "subscription") { @@ -380,3 +333,67 @@ func (c *client) createGetRequest(req *Request) (*http.Request, error) { return httpReq, nil } + +func (w *webSocketClient) StartWebSocket(ctx context.Context) (errChan chan error, err error) { + w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.Header) + if err != nil { + return nil, err + } + err = w.sendInit() + if err != nil { + w.conn.Close() + return nil, err + } + err = w.waitForConnAck() + if err != nil { + w.conn.Close() + return nil, err + } + go w.listenWebSocket(ctx) + return w.errChan, err +} + +func (w *webSocketClient) CloseWebSocket() { + defer w.conn.Close() + err := w.conn.WriteMessage(closeMessage, formatCloseMessage(closeNormalClosure, "")) + if err != nil { + w.errChan <- err + } +} + +func (w *webSocketClient) Subscribe(req *Request, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) (string, error) { + if req.Query != "" { + if strings.HasPrefix(strings.TrimSpace(req.Query), "query") { + return "", errors.New("client does not support queries") + } + if strings.HasPrefix(strings.TrimSpace(req.Query), "mutation") { + return "", errors.New("client does not support mutations") + } + } + + subscriptionID := uuid.NewString() + subscriptionMsg := webSocketSendMessage{ + Type: webSocketTypeSubscribe, + Payload: req, + ID: subscriptionID, + } + err := w.sendStructAsJSON(subscriptionMsg) + if err != nil { + return "", err + } + w.subscriptions.Create(subscriptionID, interfaceChan, forwardDataFunc) + return subscriptionID, nil +} + +func (w *webSocketClient) Unsubscribe(subscriptionID string) error { + completeMsg := webSocketSendMessage{ + Type: webSocketTypeComplete, + ID: subscriptionID, + } + err := w.sendStructAsJSON(completeMsg) + if err != nil { + return err + } + w.subscriptions.Delete(subscriptionID) + return nil +} diff --git a/graphql/subscription.go b/graphql/subscription.go new file mode 100644 index 00000000..79c765fd --- /dev/null +++ b/graphql/subscription.go @@ -0,0 +1,43 @@ +package graphql + +import ( + "encoding/json" + "sync" +) + +// map of subscription ID to subscription +type subscriptionMap struct { + map_ map[string]subscription + sync.RWMutex +} + +type subscription struct { + interfaceChan interface{} + forwardDataFunc ForwardDataFunction + id string +} + +type ForwardDataFunction func(interfaceChan interface{}, jsonRawMsg json.RawMessage) error + +func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) { + s.Lock() + defer s.Unlock() + s.map_[subscriptionID] = subscription{ + id: subscriptionID, + interfaceChan: interfaceChan, + forwardDataFunc: forwardDataFunc, + } +} + +func (s *subscriptionMap) Read(subscriptionID string) (sub subscription, success bool) { + s.RLock() + defer s.RUnlock() + sub, success = s.map_[subscriptionID] + return sub, success +} + +func (s *subscriptionMap) Delete(subscriptionID string) { + s.Lock() + defer s.Unlock() + delete(s.map_, subscriptionID) +} diff --git a/graphql/websocket.go b/graphql/websocket.go index 646dbc39..4b87593b 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -1,6 +1,7 @@ package graphql import ( + "context" "encoding/binary" "encoding/json" "fmt" @@ -39,33 +40,20 @@ const ( type webSocketSendMessage struct { Payload *Request `json:"payload"` Type string `json:"type"` + ID string `json:"id"` } type webSocketReceiveMessage struct { Type string `json:"type"` + ID string `json:"id"` Payload json.RawMessage `json:"payload"` } func (w *webSocketClient) sendInit() error { - connInit := webSocketSendMessage{ + connInitMsg := webSocketSendMessage{ Type: webSocketTypeConnInit, } - return w.sendStructAsJSON(connInit) -} - -func (w *webSocketClient) sendSubscribe(req *Request) error { - subscription := webSocketSendMessage{ - Type: webSocketTypeSubscribe, - Payload: req, - } - return w.sendStructAsJSON(subscription) -} - -func (w *webSocketClient) sendComplete() error { - complete := webSocketSendMessage{ - Type: webSocketTypeComplete, - } - return w.sendStructAsJSON(complete) + return w.sendStructAsJSON(connInitMsg) } func (w *webSocketClient) sendStructAsJSON(object any) error { @@ -92,22 +80,40 @@ func (w *webSocketClient) waitForConnAck() error { return nil } -func (w *webSocketClient) listenWebSocket(respChan chan json.RawMessage) { - defer close(respChan) +func (w *webSocketClient) listenWebSocket(ctx context.Context) { for { - _, message, err := w.conn.ReadMessage() - if err != nil { - w.errChan <- err - return - } - err = forwardWebSocketData(respChan, message) - if err != nil { - w.errChan <- err + select { + case <-ctx.Done(): + w.errChan <- fmt.Errorf("context canceled") return + default: + _, message, err := w.conn.ReadMessage() + if err != nil { + w.errChan <- err + return + } + err = w.forwardWebSocketData(message) + if err != nil { + w.errChan <- err + return + } } } } +func (w *webSocketClient) forwardWebSocketData(message []byte) error { + var wsMsg webSocketReceiveMessage + err := json.Unmarshal(message, &wsMsg) + if err != nil { + return err + } + sub, ok := w.subscriptions.Read(wsMsg.ID) + if !ok { + return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID) + } + return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload) +} + func (w *webSocketClient) receiveWebSocketConnAck() (bool, error) { _, message, err := w.conn.ReadMessage() if err != nil { @@ -125,20 +131,6 @@ func checkConnectionAckReceived(message []byte) (bool, error) { return wsMessage.Type == webSocketTypeConnAck, nil } -func forwardWebSocketData(respChan chan json.RawMessage, message []byte) error { - var wsMsg webSocketReceiveMessage - err := json.Unmarshal(message, &wsMsg) - if err != nil { - return err - } - switch wsMsg.Type { - case webSocketTypeNext, webSocketTypeError: - respChan <- wsMsg.Payload - default: - } - return nil -} - // formatCloseMessage formats closeCode and text as a WebSocket close message. // An empty message is returned for code CloseNoStatusReceived. func formatCloseMessage(closeCode int, text string) []byte { diff --git a/internal/integration/generated.go b/internal/integration/generated.go index fbf71bc6..371fd1d3 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -5,6 +5,7 @@ package integration import ( "context" "encoding/json" + "errors" "fmt" "time" @@ -3093,26 +3094,20 @@ subscription count { } ` -// To close the connection, use [graphql.WebSocketClient.CloseWebSocket()] +// To unsubscribe, use [graphql.WebSocketClient.Unsubscribe] func count( ctx_ context.Context, client_ graphql.WebSocketClient, -) (dataChan_ chan countWsResponse, errChan_ chan error, err_ error) { +) (dataChan_ chan countWsResponse, subscriptionID_ string, err_ error) { req_ := &graphql.Request{ OpName: "count", Query: count_Operation, } - dataChan_ = make(chan countWsResponse, 1) - respChan_ := make(chan json.RawMessage, 1) + dataChan_ = make(chan countWsResponse) + subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, countForwardData) - errChan_, err_ = client_.DialWebSocket(ctx_, req_, respChan_) - if err_ != nil { - return nil, nil, err_ - } - go countForwardData(dataChan_, respChan_, errChan_) - - return dataChan_, errChan_, err_ + return dataChan_, subscriptionID_, err_ } type countWsResponse struct { @@ -3121,31 +3116,27 @@ type countWsResponse struct { Errors error `json:"errors"` } -func countForwardData(dataChan_ chan countWsResponse, respChan_ chan json.RawMessage, errChan_ chan error) { - defer close(dataChan_) +func countForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { var gqlResp graphql.Response var wsResp countWsResponse - for { - jsonRaw, more_ := <-respChan_ - if !more_ { - return - } - err := json.Unmarshal(jsonRaw, &gqlResp) + err := json.Unmarshal(jsonRawMsg, &gqlResp) + if err != nil { + return err + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRawMsg, &wsResp) if err != nil { - errChan_ <- err - return + return err } - if len(gqlResp.Errors) == 0 { - err = json.Unmarshal(jsonRaw, &wsResp) - if err != nil { - errChan_ <- err - return - } - } else { - wsResp.Errors = gqlResp.Errors - } - dataChan_ <- wsResp + } else { + wsResp.Errors = gqlResp.Errors } + dataChan_, ok := interfaceChan.(chan countWsResponse) + if !ok { + return errors.New("failed to cast interface into 'chan countWsResponse'") + } + dataChan_ <- wsResp + return nil } // The mutation executed by createUser. diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 0ed4520d..53cfc58d 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -66,14 +66,17 @@ func TestSubscription(t *testing.T) { defer server.Close() wsClient := newRoundtripWebScoketClient(t, server.URL) - start := time.Now() - respChan, errChan, err := count(ctx, wsClient) + errChan, err := wsClient.StartWebSocket(ctx) + require.NoError(t, err) + + dataChan, subscriptionID, err := count(ctx, wsClient) require.NoError(t, err) defer wsClient.CloseWebSocket() counter := 0 + start := time.Now() for loop := true; loop; { select { - case resp, more := <-respChan: + case resp, more := <-dataChan: if !more { loop = false break @@ -81,11 +84,15 @@ func TestSubscription(t *testing.T) { require.NotNil(t, resp.Data) assert.Equal(t, counter, resp.Data.Count) require.Nil(t, resp.Errors) - loop = time.Since(start) < time.Second*2 + if time.Since(start) > time.Second*5 { + err = wsClient.Unsubscribe(subscriptionID) + require.NoError(t, err) + loop = false + } counter++ case err := <-errChan: require.NoError(t, err) - case <-time.After(time.Second * 5): + case <-time.After(time.Second * 10): require.NoError(t, fmt.Errorf("subscription timed out")) } } diff --git a/internal/integration/roundtrip.go b/internal/integration/roundtrip.go index 8b386c84..665c52f9 100644 --- a/internal/integration/roundtrip.go +++ b/internal/integration/roundtrip.go @@ -106,14 +106,22 @@ func (c *roundtripClient) MakeRequest(ctx context.Context, req *graphql.Request, return nil } -func (c *roundtripClient) DialWebSocket(ctx context.Context, req *graphql.Request, respChan chan json.RawMessage) (errChan chan error, err error) { - return c.wsWrapped.DialWebSocket(ctx, req, respChan) +func (c *roundtripClient) StartWebSocket(ctx context.Context) (errChan chan error, err error) { + return c.wsWrapped.StartWebSocket(ctx) } func (c *roundtripClient) CloseWebSocket() { c.wsWrapped.CloseWebSocket() } +func (c *roundtripClient) Subscribe(req *graphql.Request, interfaceChan interface{}, forwardDataFunc graphql.ForwardDataFunction) (string, error) { + return c.wsWrapped.Subscribe(req, interfaceChan, forwardDataFunc) +} + +func (c *roundtripClient) Unsubscribe(subscriptionID string) error { + return c.wsWrapped.Unsubscribe(subscriptionID) +} + func newRoundtripClients(t *testing.T, endpoint string) []graphql.Client { return []graphql.Client{newRoundtripClient(t, endpoint), newRoundtripGetClient(t, endpoint)} }