From 049f7f1880e12048821eb98fd5c3020504ce3b56 Mon Sep 17 00:00:00 2001 From: matthieu4294967296moineau Date: Fri, 6 Oct 2023 09:28:17 +0200 Subject: [PATCH 1/9] removed package gorilla/websocket from lint file --- .golangci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.golangci.yml b/.golangci.yml index 3c294267..7f96025c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -49,7 +49,6 @@ linters-settings: - golang.org/x/tools - gopkg.in/yaml.v2 - github.com/alexflint/go-arg - - github.com/gorilla/websocket forbidigo: forbid: From 5b6f4bf31312f90b1d073f37f1f4739230992adf Mon Sep 17 00:00:00 2001 From: matthieu4294967296moineau Date: Fri, 6 Oct 2023 09:33:58 +0200 Subject: [PATCH 2/9] added link to doc --- docs/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 497fdc66..d22d07fa 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -26,7 +26,7 @@ When releasing a new version: - The new `optional: generic` allows using a generic type to represent optionality. See the [documentation](genqlient.yaml) for details. - For schemas with enum values that differ only in casing, it's now possible to disable smart-casing in genqlient.yaml; see the [documentation](genqlient.yaml) for `casing` for details. -- genqlient now supports subscriptions +- genqlient now supports subscriptions; see the [documentation](FAQ.md) for how to `subscribe to an API 'subscription' endpoint` ### Bug fixes: - The presence of negative pointer directives, i.e., `# @genqlient(pointer: false)` are now respected even in the when `optional: pointer` is set in the configuration file. From f5d2e926cb756ca97c88b21e89846a4a9488f6bd Mon Sep 17 00:00:00 2001 From: matthieu4294967296moineau Date: Fri, 6 Oct 2023 09:36:31 +0200 Subject: [PATCH 3/9] remove useless underscore --- generate/operation.go.tmpl | 4 ++-- ...e-SimpleSubscription.graphql-SimpleSubscription.graphql.go | 4 ++-- internal/integration/generated.go | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/generate/operation.go.tmpl b/generate/operation.go.tmpl index 1d32c88c..94ad55b5 100644 --- a/generate/operation.go.tmpl +++ b/generate/operation.go.tmpl @@ -70,8 +70,8 @@ func {{.Name}}ForwardData(dataChan_ chan {{.Name}}WsResponse, respChan_ chan jso var gqlResp graphql.Response var wsResp {{.Name}}WsResponse for { - jsonRaw, more_ := <-respChan_ - if !more_ { + jsonRaw, more := <-respChan_ + if !more { return } err := json.Unmarshal(jsonRaw, &gqlResp) diff --git a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go index e236254e..8c117496 100644 --- a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go +++ b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go @@ -56,8 +56,8 @@ func SimpleSubscriptionForwardData(dataChan_ chan SimpleSubscriptionWsResponse, var gqlResp graphql.Response var wsResp SimpleSubscriptionWsResponse for { - jsonRaw, more_ := <-respChan_ - if !more_ { + jsonRaw, more := <-respChan_ + if !more { return } err := json.Unmarshal(jsonRaw, &gqlResp) diff --git a/internal/integration/generated.go b/internal/integration/generated.go index fbf71bc6..30006c9c 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -3126,8 +3126,8 @@ func countForwardData(dataChan_ chan countWsResponse, respChan_ chan json.RawMes var gqlResp graphql.Response var wsResp countWsResponse for { - jsonRaw, more_ := <-respChan_ - if !more_ { + jsonRaw, more := <-respChan_ + if !more { return } err := json.Unmarshal(jsonRaw, &gqlResp) From 70661393e47502b15fa506819cef4b0243dfc42f Mon Sep 17 00:00:00 2001 From: matthieu4294967296moineau Date: Fri, 6 Oct 2023 09:38:53 +0200 Subject: [PATCH 4/9] bis --- generate/operation.go.tmpl | 4 ++-- ...e-SimpleSubscription.graphql-SimpleSubscription.graphql.go | 4 ++-- internal/integration/generated.go | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/generate/operation.go.tmpl b/generate/operation.go.tmpl index 94ad55b5..b56d2c5a 100644 --- a/generate/operation.go.tmpl +++ b/generate/operation.go.tmpl @@ -65,12 +65,12 @@ type {{.Name}}WsResponse struct { Errors error `json:"errors"` } -func {{.Name}}ForwardData(dataChan_ chan {{.Name}}WsResponse, respChan_ chan json.RawMessage, errChan_ chan error) { +func {{.Name}}ForwardData(dataChan_ chan {{.Name}}WsResponse, respChan chan json.RawMessage, errChan_ chan error) { defer close(dataChan_) var gqlResp graphql.Response var wsResp {{.Name}}WsResponse for { - jsonRaw, more := <-respChan_ + jsonRaw, more := <-respChan if !more { return } diff --git a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go index 8c117496..1791ffec 100644 --- a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go +++ b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go @@ -51,12 +51,12 @@ type SimpleSubscriptionWsResponse struct { Errors error `json:"errors"` } -func SimpleSubscriptionForwardData(dataChan_ chan SimpleSubscriptionWsResponse, respChan_ chan json.RawMessage, errChan_ chan error) { +func SimpleSubscriptionForwardData(dataChan_ chan SimpleSubscriptionWsResponse, respChan chan json.RawMessage, errChan_ chan error) { defer close(dataChan_) var gqlResp graphql.Response var wsResp SimpleSubscriptionWsResponse for { - jsonRaw, more := <-respChan_ + jsonRaw, more := <-respChan if !more { return } diff --git a/internal/integration/generated.go b/internal/integration/generated.go index 30006c9c..ff62a46b 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -3121,12 +3121,12 @@ type countWsResponse struct { Errors error `json:"errors"` } -func countForwardData(dataChan_ chan countWsResponse, respChan_ chan json.RawMessage, errChan_ chan error) { +func countForwardData(dataChan_ chan countWsResponse, respChan chan json.RawMessage, errChan_ chan error) { defer close(dataChan_) var gqlResp graphql.Response var wsResp countWsResponse for { - jsonRaw, more := <-respChan_ + jsonRaw, more := <-respChan if !more { return } From 28987e18a95dce933a4b467cb35fff0651948ee6 Mon Sep 17 00:00:00 2001 From: matthieu4294967296moineau Date: Fri, 6 Oct 2023 09:55:41 +0200 Subject: [PATCH 5/9] possible to define other websocket protocol via header --- docs/CHANGELOG.md | 3 ++- docs/FAQ.md | 6 +++++- graphql/client.go | 4 +++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index d22d07fa..da5e604b 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -26,7 +26,8 @@ When releasing a new version: - The new `optional: generic` allows using a generic type to represent optionality. See the [documentation](genqlient.yaml) for details. - For schemas with enum values that differ only in casing, it's now possible to disable smart-casing in genqlient.yaml; see the [documentation](genqlient.yaml) for `casing` for details. -- genqlient now supports subscriptions; see the [documentation](FAQ.md) for how to `subscribe to an API 'subscription' endpoint` +- genqlient now supports subscriptions; the websocket protocol is by default `graphql-transport-ws` but can be set to another value. + See the [documentation](FAQ.md) for how to `subscribe to an API 'subscription' endpoint`. ### Bug fixes: - The presence of negative pointer directives, i.e., `# @genqlient(pointer: false)` are now respected even in the when `optional: pointer` is set in the configuration file. diff --git a/docs/FAQ.md b/docs/FAQ.md index b7b49226..f1b49325 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -108,7 +108,7 @@ Once your webSocket client matches the interfaces, you can get your `graphql.Web a loop for incoming messages and errors: ```go -graphqlClient := graphql.NewClientUsingWebSocket( + graphqlClient := graphql.NewClientUsingWebSocket( "ws://localhost:8080/query", &MyDialer{Dialer: dialer}, headers, @@ -139,6 +139,10 @@ graphqlClient := graphql.NewClientUsingWebSocket( } } ``` +To change the websocket protocol from its default value `graphql-transport-ws`, add the following header before calling `graphql.NewClientUsingWebSocket()`: +```go + headers.Add("Sec-WebSocket-Protocol", "graphql-ws") +``` ### … use an API that requires authentication? diff --git a/graphql/client.go b/graphql/client.go index b691ac14..94dbe8b7 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -132,7 +132,9 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Head if headers == nil { headers = http.Header{} } - headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws") + if headers.Get("Sec-WebSocket-Protocol") == "" { + headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws") + } return &webSocketClient{ Dialer: wsDialer, Header: headers, From 5487b0bdca5d53b0f5c1d23fb9e0a9f498fca729 Mon Sep 17 00:00:00 2001 From: matthieu4294967296moineau Date: Fri, 6 Oct 2023 11:37:09 +0200 Subject: [PATCH 6/9] removed buffer from channels --- generate/operation.go.tmpl | 4 ++-- ...e-SimpleSubscription.graphql-SimpleSubscription.graphql.go | 4 ++-- graphql/client.go | 2 +- internal/integration/generated.go | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/generate/operation.go.tmpl b/generate/operation.go.tmpl index b56d2c5a..0edd3133 100644 --- a/generate/operation.go.tmpl +++ b/generate/operation.go.tmpl @@ -36,8 +36,8 @@ func {{.Name}}( } {{end}} {{if eq .Type "subscription"}} - dataChan_ = make(chan {{.Name}}WsResponse, 1) - respChan_ := make(chan json.RawMessage, 1) + dataChan_ = make(chan {{.Name}}WsResponse) + respChan_ := make(chan json.RawMessage) errChan_, err_ = client_.DialWebSocket({{if ne .Config.ContextType "-" -}}ctx_{{else}}context.Background(){{end}}, req_, respChan_) if err_ != nil { diff --git a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go index 1791ffec..1857fe9e 100644 --- a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go +++ b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go @@ -33,8 +33,8 @@ func SimpleSubscription( Query: SimpleSubscription_Operation, } - dataChan_ = make(chan SimpleSubscriptionWsResponse, 1) - respChan_ := make(chan json.RawMessage, 1) + dataChan_ = make(chan SimpleSubscriptionWsResponse) + respChan_ := make(chan json.RawMessage) errChan_, err_ = client_.DialWebSocket(context.Background(), req_, respChan_) if err_ != nil { diff --git a/graphql/client.go b/graphql/client.go index 94dbe8b7..c92cd026 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -139,7 +139,7 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Head Dialer: wsDialer, Header: headers, endpoint: endpoint, - errChan: make(chan error, 1), + errChan: make(chan error), } } diff --git a/internal/integration/generated.go b/internal/integration/generated.go index ff62a46b..4edafebf 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -3103,8 +3103,8 @@ func count( Query: count_Operation, } - dataChan_ = make(chan countWsResponse, 1) - respChan_ := make(chan json.RawMessage, 1) + dataChan_ = make(chan countWsResponse) + respChan_ := make(chan json.RawMessage) errChan_, err_ = client_.DialWebSocket(ctx_, req_, respChan_) if err_ != nil { From cf8b2c680d06fcdf8a761fe1b3af748c9c36efea Mon Sep 17 00:00:00 2001 From: matthieu4294967296moineau Date: Fri, 6 Oct 2023 14:50:08 +0200 Subject: [PATCH 7/9] checking context --- generate/operation.go.tmpl | 50 ++++++++++++++++++------------- internal/integration/generated.go | 37 +++++++++++++---------- 2 files changed, 51 insertions(+), 36 deletions(-) diff --git a/generate/operation.go.tmpl b/generate/operation.go.tmpl index 0edd3133..ca849244 100644 --- a/generate/operation.go.tmpl +++ b/generate/operation.go.tmpl @@ -43,7 +43,7 @@ func {{.Name}}( if err_ != nil { return nil, nil, err_ } - go {{.Name}}ForwardData(dataChan_, respChan_, errChan_) + go {{.Name}}ForwardData({{if ne .Config.ContextType "-" -}}ctx_, {{end}}dataChan_, respChan_, errChan_) {{else}} data_ = &{{.ResponseName}}{} resp_ := &graphql.Response{Data: data_} @@ -65,30 +65,38 @@ type {{.Name}}WsResponse struct { Errors error `json:"errors"` } -func {{.Name}}ForwardData(dataChan_ chan {{.Name}}WsResponse, respChan chan json.RawMessage, errChan_ chan error) { +func {{.Name}}ForwardData({{if ne .Config.ContextType "-" -}}ctx_ {{ref .Config.ContextType}}, {{end}}dataChan_ chan {{.Name}}WsResponse, respChan chan json.RawMessage, errChan_ chan error) { defer close(dataChan_) var gqlResp graphql.Response var wsResp {{.Name}}WsResponse for { - jsonRaw, more := <-respChan - if !more { - return - } - err := json.Unmarshal(jsonRaw, &gqlResp) - if err != nil { - errChan_ <- err - return - } - if len(gqlResp.Errors) == 0 { - err = json.Unmarshal(jsonRaw, &wsResp) - if err != nil { - errChan_ <- err - return - } - } else { - wsResp.Errors = gqlResp.Errors - } - dataChan_ <- wsResp + {{if ne .Config.ContextType "-" -}} + select { + case <-ctx_.Done(): + errChan_ <- errors.New("context was canceled") + return + default: + {{end -}} + jsonRaw, more := <-respChan + if !more { + return + } + err := json.Unmarshal(jsonRaw, &gqlResp) + if err != nil { + errChan_ <- err + return + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRaw, &wsResp) + if err != nil { + errChan_ <- err + return + } + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_ <- wsResp + {{if ne .Config.ContextType "-" -}}}{{end -}} } } {{end}} diff --git a/internal/integration/generated.go b/internal/integration/generated.go index 4edafebf..4a55eb3d 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -5,6 +5,7 @@ package integration import ( "context" "encoding/json" + "errors" "fmt" "time" @@ -3110,7 +3111,7 @@ func count( if err_ != nil { return nil, nil, err_ } - go countForwardData(dataChan_, respChan_, errChan_) + go countForwardData(ctx_, dataChan_, respChan_, errChan_) return dataChan_, errChan_, err_ } @@ -3121,30 +3122,36 @@ type countWsResponse struct { Errors error `json:"errors"` } -func countForwardData(dataChan_ chan countWsResponse, respChan chan json.RawMessage, errChan_ chan error) { +func countForwardData(ctx_ context.Context, dataChan_ chan countWsResponse, respChan chan json.RawMessage, errChan_ chan error) { defer close(dataChan_) var gqlResp graphql.Response var wsResp countWsResponse for { - jsonRaw, more := <-respChan - if !more { + select { + case <-ctx_.Done(): + errChan_ <- errors.New("context was canceled") return - } - err := json.Unmarshal(jsonRaw, &gqlResp) - if err != nil { - errChan_ <- err - return - } - if len(gqlResp.Errors) == 0 { - err = json.Unmarshal(jsonRaw, &wsResp) + default: + jsonRaw, more := <-respChan + if !more { + return + } + err := json.Unmarshal(jsonRaw, &gqlResp) if err != nil { errChan_ <- err return } - } else { - wsResp.Errors = gqlResp.Errors + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRaw, &wsResp) + if err != nil { + errChan_ <- err + return + } + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_ <- wsResp } - dataChan_ <- wsResp } } From 0781863043730e96d7b2e01573329b39b3511d7d Mon Sep 17 00:00:00 2001 From: matthieu4294967296moineau Date: Mon, 9 Oct 2023 17:19:00 +0200 Subject: [PATCH 8/9] only 1 goroutine; multiple subscriptions per connection --- .golangci.yml | 1 + docs/FAQ.md | 12 +- generate/generate.go | 2 +- generate/operation.go.tmpl | 60 ++---- ...tion.graphql-SimpleSubscription.graphql.go | 52 ++--- go.mod | 1 + go.sum | 2 + graphql/client.go | 183 ++++++++++-------- graphql/subscription.go | 43 ++++ graphql/websocket.go | 53 ++--- internal/integration/generated.go | 58 ++---- internal/integration/integration_test.go | 17 +- internal/integration/roundtrip.go | 12 +- 13 files changed, 261 insertions(+), 235 deletions(-) create mode 100644 graphql/subscription.go diff --git a/.golangci.yml b/.golangci.yml index 7f96025c..6e65e014 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -49,6 +49,7 @@ linters-settings: - golang.org/x/tools - gopkg.in/yaml.v2 - github.com/alexflint/go-arg + - github.com/google/uuid forbidigo: forbid: diff --git a/docs/FAQ.md b/docs/FAQ.md index f1b49325..66dde308 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -114,7 +114,12 @@ a loop for incoming messages and errors: headers, ) - respChan, errChan, err := count(context.Background(), graphqlClient) + errChan, err := graphqlClient.StartWebSocket(ctx) + if err != nil { + return + } + + dataChan, subscriptionID, err := count(ctx, graphqlClient) if err != nil { return } @@ -122,7 +127,7 @@ a loop for incoming messages and errors: defer graphqlClient.CloseWebSocket() for loop := true; loop; { select { - case msg, more := <-respChan: + case msg, more := <-dataChan: if !more { loop = false break @@ -136,6 +141,9 @@ a loop for incoming messages and errors: } case err = <-errChan: return + case <-time.After(time.Minute): + err = wsClient.Unsubscribe(subscriptionID) + loop = false } } ``` diff --git a/generate/generate.go b/generate/generate.go index 07661ca9..be26301a 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -282,7 +282,7 @@ func (g *generator) addOperation(op *ast.OperationDefinition) error { docComment = "// " + strings.ReplaceAll(commentLines, "\n", "\n// ") } if op.Operation == ast.Subscription { - docComment += "\n// To close the connection, use [graphql.WebSocketClient.CloseWebSocket()]" + docComment += "\n// To unsubscribe, use [graphql.WebSocketClient.Unsubscribe]" } // If the filename is a pseudo-filename filename.go:startline, just diff --git a/generate/operation.go.tmpl b/generate/operation.go.tmpl index ca849244..1b70890e 100644 --- a/generate/operation.go.tmpl +++ b/generate/operation.go.tmpl @@ -15,7 +15,7 @@ func {{.Name}}( {{.GraphQLName}} {{.GoType.Reference}}, {{end -}} {{end -}} -) ({{if eq .Type "subscription"}}dataChan_ chan {{.Name}}WsResponse, errChan_ chan error,{{else}}data_ *{{.ResponseName}}, {{if .Config.Extensions -}}ext_ map[string]interface{},{{end}}{{end}} err_ error) { +) ({{if eq .Type "subscription"}}dataChan_ chan {{.Name}}WsResponse, subscriptionID_ string,{{else}}data_ *{{.ResponseName}}, {{if .Config.Extensions -}}ext_ map[string]interface{},{{end}}{{end}} err_ error) { req_ := &graphql.Request{ OpName: "{{.Name}}", Query: {{.Name}}_Operation, @@ -37,13 +37,7 @@ func {{.Name}}( {{end}} {{if eq .Type "subscription"}} dataChan_ = make(chan {{.Name}}WsResponse) - respChan_ := make(chan json.RawMessage) - - errChan_, err_ = client_.DialWebSocket({{if ne .Config.ContextType "-" -}}ctx_{{else}}context.Background(){{end}}, req_, respChan_) - if err_ != nil { - return nil, nil, err_ - } - go {{.Name}}ForwardData({{if ne .Config.ContextType "-" -}}ctx_, {{end}}dataChan_, respChan_, errChan_) + subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, {{.Name}}ForwardData) {{else}} data_ = &{{.ResponseName}}{} resp_ := &graphql.Response{Data: data_} @@ -55,7 +49,7 @@ func {{.Name}}( ) {{end}} - return {{if eq .Type "subscription"}}dataChan_, errChan_,{{else}}data_, {{if .Config.Extensions -}}resp_.Extensions,{{end -}}{{end}} err_ + return {{if eq .Type "subscription"}}dataChan_, subscriptionID_,{{else}}data_, {{if .Config.Extensions -}}resp_.Extensions,{{end -}}{{end}} err_ } {{if eq .Type "subscription"}} @@ -65,38 +59,26 @@ type {{.Name}}WsResponse struct { Errors error `json:"errors"` } -func {{.Name}}ForwardData({{if ne .Config.ContextType "-" -}}ctx_ {{ref .Config.ContextType}}, {{end}}dataChan_ chan {{.Name}}WsResponse, respChan chan json.RawMessage, errChan_ chan error) { - defer close(dataChan_) +func {{.Name}}ForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { var gqlResp graphql.Response var wsResp {{.Name}}WsResponse - for { - {{if ne .Config.ContextType "-" -}} - select { - case <-ctx_.Done(): - errChan_ <- errors.New("context was canceled") - return - default: - {{end -}} - jsonRaw, more := <-respChan - if !more { - return - } - err := json.Unmarshal(jsonRaw, &gqlResp) - if err != nil { - errChan_ <- err - return - } - if len(gqlResp.Errors) == 0 { - err = json.Unmarshal(jsonRaw, &wsResp) - if err != nil { - errChan_ <- err - return - } - } else { - wsResp.Errors = gqlResp.Errors - } - dataChan_ <- wsResp - {{if ne .Config.ContextType "-" -}}}{{end -}} + err := json.Unmarshal(jsonRawMsg, &gqlResp) + if err != nil { + return err + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRawMsg, &wsResp) + if err != nil { + return err + } + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_, ok := interfaceChan.(chan {{.Name}}WsResponse) + if !ok { + return errors.New("failed to cast interface into 'chan {{.Name}}WsResponse'") } + dataChan_ <- wsResp + return nil } {{end}} diff --git a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go index 1857fe9e..4f1ebca6 100644 --- a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go +++ b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go @@ -3,8 +3,8 @@ package test import ( - "context" "encoding/json" + "errors" "github.com/Khan/genqlient/graphql" ) @@ -24,25 +24,19 @@ subscription SimpleSubscription { } ` -// To close the connection, use [graphql.WebSocketClient.CloseWebSocket()] +// To unsubscribe, use [graphql.WebSocketClient.Unsubscribe] func SimpleSubscription( client_ graphql.WebSocketClient, -) (dataChan_ chan SimpleSubscriptionWsResponse, errChan_ chan error, err_ error) { +) (dataChan_ chan SimpleSubscriptionWsResponse, subscriptionID_ string, err_ error) { req_ := &graphql.Request{ OpName: "SimpleSubscription", Query: SimpleSubscription_Operation, } dataChan_ = make(chan SimpleSubscriptionWsResponse) - respChan_ := make(chan json.RawMessage) + subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, SimpleSubscriptionForwardData) - errChan_, err_ = client_.DialWebSocket(context.Background(), req_, respChan_) - if err_ != nil { - return nil, nil, err_ - } - go SimpleSubscriptionForwardData(dataChan_, respChan_, errChan_) - - return dataChan_, errChan_, err_ + return dataChan_, subscriptionID_, err_ } type SimpleSubscriptionWsResponse struct { @@ -51,30 +45,26 @@ type SimpleSubscriptionWsResponse struct { Errors error `json:"errors"` } -func SimpleSubscriptionForwardData(dataChan_ chan SimpleSubscriptionWsResponse, respChan chan json.RawMessage, errChan_ chan error) { - defer close(dataChan_) +func SimpleSubscriptionForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { var gqlResp graphql.Response var wsResp SimpleSubscriptionWsResponse - for { - jsonRaw, more := <-respChan - if !more { - return - } - err := json.Unmarshal(jsonRaw, &gqlResp) + err := json.Unmarshal(jsonRawMsg, &gqlResp) + if err != nil { + return err + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRawMsg, &wsResp) if err != nil { - errChan_ <- err - return - } - if len(gqlResp.Errors) == 0 { - err = json.Unmarshal(jsonRaw, &wsResp) - if err != nil { - errChan_ <- err - return - } - } else { - wsResp.Errors = gqlResp.Errors + return err } - dataChan_ <- wsResp + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_, ok := interfaceChan.(chan SimpleSubscriptionWsResponse) + if !ok { + return errors.New("failed to cast interface into 'chan SimpleSubscriptionWsResponse'") } + dataChan_ <- wsResp + return nil } diff --git a/go.mod b/go.mod index d6967948..6d3f2050 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/99designs/gqlgen v0.17.35 github.com/alexflint/go-arg v1.4.2 github.com/bradleyjkemp/cupaloy/v2 v2.6.0 + github.com/google/uuid v1.3.1 github.com/gorilla/websocket v1.5.0 github.com/stretchr/testify v1.8.2 github.com/vektah/gqlparser/v2 v2.5.8 diff --git a/go.sum b/go.sum index 9388a4f5..885c5957 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= +github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= +github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru/v2 v2.0.3 h1:kmRrRLlInXvng0SmLxmQpQkpbYAvcXm7NPDrgxJa9mE= diff --git a/graphql/client.go b/graphql/client.go index c92cd026..34882b1f 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -10,7 +10,9 @@ import ( "net/http" "net/url" "strings" + "sync" + "github.com/google/uuid" "github.com/vektah/gqlparser/v2/gqlerror" ) @@ -41,32 +43,38 @@ type Client interface { } type WebSocketClient interface { - // DialWebSocket must open a webSocket connection and subscribe to an endpoint + // StartWebSocket must open a webSocket connection and subscribe to an endpoint // of the client's GraphQL API. // - // ctx is the context that should be used to make this request. If context - // is disabled in the genqlient settings, this will be set to - // context.Background(). + // errChan is a channel on which are sent the errors of webSocket + // communication. + // + // err is any error that occurs when setting up the webSocket connection. + StartWebSocket(ctx context.Context) (errChan chan error, err error) + + // CloseWebSocket must close the webSocket connection. + CloseWebSocket() + + // Subscribe must subscribe to an endpoint of the client's GraphQL API. // // req contains the data to be sent to the GraphQL server. Will be marshalled // into JSON bytes. // - // respChan is a channel used to send the data that arrives via the + // interfaceChan is a channel used to send the data that arrives via the // webSocket connection. // - // errChan is a channel on which are sent the errors of webSocket - // communication. + // forwardDataFunc is the function that will cast the received interface into + // the valid type for the subscription's response. // - // err is any error that occurs when setting up the webSocket connection. - DialWebSocket( - ctx context.Context, + // returns a subscriptionID if successful, an error otherwise + Subscribe( req *Request, - respChan chan json.RawMessage, - ) (errChan chan error, err error) + interfaceChan interface{}, + forwardDataFunc ForwardDataFunction, + ) (string, error) - // CloseWebSocket must end the graphql subscription and close the webSocket - // connection. - CloseWebSocket() + // Unsubscribe must unsubscribe from an endpoint of the client's GraphQL API. + Unsubscribe(subscriptionID string) error } type client struct { @@ -76,11 +84,13 @@ type client struct { } type webSocketClient struct { - Dialer Dialer - Header http.Header - conn WSConn - errChan chan error - endpoint string + Dialer Dialer + Header http.Header + conn WSConn + errChan chan error + endpoint string + subscriptions subscriptionMap + sync.RWMutex } // NewClient returns a [Client] which makes requests to the given endpoint, @@ -136,10 +146,11 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Head headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws") } return &webSocketClient{ - Dialer: wsDialer, - Header: headers, - endpoint: endpoint, - errChan: make(chan error), + Dialer: wsDialer, + Header: headers, + errChan: make(chan error), + endpoint: endpoint, + subscriptions: subscriptionMap{map_: make(map[string]subscription)}, } } @@ -251,66 +262,6 @@ func (c *client) MakeRequest(ctx context.Context, req *Request, resp *Response) return nil } -func (w *webSocketClient) DialWebSocket(ctx context.Context, req *Request, respChan chan json.RawMessage) (errChan chan error, err error) { - if req.Query != "" { - if strings.HasPrefix(strings.TrimSpace(req.Query), "query") { - return nil, errors.New("client does not support queries") - } - if strings.HasPrefix(strings.TrimSpace(req.Query), "mutation") { - return nil, errors.New("client does not support mutations") - } - } - - err = w.subscribeAndListen( - ctx, - req, - respChan, - ) - - return w.errChan, err -} - -func (w *webSocketClient) CloseWebSocket() { - defer w.conn.Close() - err := w.sendComplete() - if err != nil { - w.errChan <- err - } - err = w.conn.WriteMessage(closeMessage, formatCloseMessage(closeNormalClosure, "")) - if err != nil { - w.errChan <- err - } -} - -func (w *webSocketClient) subscribeAndListen(ctx context.Context, req *Request, respChan chan json.RawMessage) error { - var err error - w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.Header) - if err != nil { - return err - } - - err = w.sendInit() - if err != nil { - w.conn.Close() - return err - } - err = w.waitForConnAck() - if err != nil { - w.conn.Close() - return err - } - - go w.listenWebSocket(respChan) - - err = w.sendSubscribe(req) - if err != nil { - w.conn.Close() - return err - } - - return nil -} - func (c *client) createPostRequest(req *Request) (*http.Request, error) { if req.Query != "" { if strings.HasPrefix(strings.TrimSpace(req.Query), "subscription") { @@ -382,3 +333,67 @@ func (c *client) createGetRequest(req *Request) (*http.Request, error) { return httpReq, nil } + +func (w *webSocketClient) StartWebSocket(ctx context.Context) (errChan chan error, err error) { + w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.Header) + if err != nil { + return nil, err + } + err = w.sendInit() + if err != nil { + w.conn.Close() + return nil, err + } + err = w.waitForConnAck() + if err != nil { + w.conn.Close() + return nil, err + } + go w.listenWebSocket() + return w.errChan, err +} + +func (w *webSocketClient) CloseWebSocket() { + defer w.conn.Close() + err := w.conn.WriteMessage(closeMessage, formatCloseMessage(closeNormalClosure, "")) + if err != nil { + w.errChan <- err + } +} + +func (w *webSocketClient) Subscribe(req *Request, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) (string, error) { + if req.Query != "" { + if strings.HasPrefix(strings.TrimSpace(req.Query), "query") { + return "", errors.New("client does not support queries") + } + if strings.HasPrefix(strings.TrimSpace(req.Query), "mutation") { + return "", errors.New("client does not support mutations") + } + } + + subscriptionID := uuid.NewString() + subscriptionMsg := webSocketSendMessage{ + Type: webSocketTypeSubscribe, + Payload: req, + ID: subscriptionID, + } + err := w.sendStructAsJSON(subscriptionMsg) + if err != nil { + return "", err + } + w.subscriptions.Create(subscriptionID, interfaceChan, forwardDataFunc) + return subscriptionID, nil +} + +func (w *webSocketClient) Unsubscribe(subscriptionID string) error { + completeMsg := webSocketSendMessage{ + Type: webSocketTypeComplete, + ID: subscriptionID, + } + err := w.sendStructAsJSON(completeMsg) + if err != nil { + return err + } + w.subscriptions.Delete(subscriptionID) + return nil +} diff --git a/graphql/subscription.go b/graphql/subscription.go new file mode 100644 index 00000000..79c765fd --- /dev/null +++ b/graphql/subscription.go @@ -0,0 +1,43 @@ +package graphql + +import ( + "encoding/json" + "sync" +) + +// map of subscription ID to subscription +type subscriptionMap struct { + map_ map[string]subscription + sync.RWMutex +} + +type subscription struct { + interfaceChan interface{} + forwardDataFunc ForwardDataFunction + id string +} + +type ForwardDataFunction func(interfaceChan interface{}, jsonRawMsg json.RawMessage) error + +func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) { + s.Lock() + defer s.Unlock() + s.map_[subscriptionID] = subscription{ + id: subscriptionID, + interfaceChan: interfaceChan, + forwardDataFunc: forwardDataFunc, + } +} + +func (s *subscriptionMap) Read(subscriptionID string) (sub subscription, success bool) { + s.RLock() + defer s.RUnlock() + sub, success = s.map_[subscriptionID] + return sub, success +} + +func (s *subscriptionMap) Delete(subscriptionID string) { + s.Lock() + defer s.Unlock() + delete(s.map_, subscriptionID) +} diff --git a/graphql/websocket.go b/graphql/websocket.go index 646dbc39..2bd04ceb 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -39,33 +39,20 @@ const ( type webSocketSendMessage struct { Payload *Request `json:"payload"` Type string `json:"type"` + ID string `json:"id"` } type webSocketReceiveMessage struct { Type string `json:"type"` + ID string `json:"id"` Payload json.RawMessage `json:"payload"` } func (w *webSocketClient) sendInit() error { - connInit := webSocketSendMessage{ + connInitMsg := webSocketSendMessage{ Type: webSocketTypeConnInit, } - return w.sendStructAsJSON(connInit) -} - -func (w *webSocketClient) sendSubscribe(req *Request) error { - subscription := webSocketSendMessage{ - Type: webSocketTypeSubscribe, - Payload: req, - } - return w.sendStructAsJSON(subscription) -} - -func (w *webSocketClient) sendComplete() error { - complete := webSocketSendMessage{ - Type: webSocketTypeComplete, - } - return w.sendStructAsJSON(complete) + return w.sendStructAsJSON(connInitMsg) } func (w *webSocketClient) sendStructAsJSON(object any) error { @@ -92,15 +79,14 @@ func (w *webSocketClient) waitForConnAck() error { return nil } -func (w *webSocketClient) listenWebSocket(respChan chan json.RawMessage) { - defer close(respChan) +func (w *webSocketClient) listenWebSocket() { for { _, message, err := w.conn.ReadMessage() if err != nil { w.errChan <- err return } - err = forwardWebSocketData(respChan, message) + err = w.forwardWebSocketData(message) if err != nil { w.errChan <- err return @@ -108,6 +94,19 @@ func (w *webSocketClient) listenWebSocket(respChan chan json.RawMessage) { } } +func (w *webSocketClient) forwardWebSocketData(message []byte) error { + var wsMsg webSocketReceiveMessage + err := json.Unmarshal(message, &wsMsg) + if err != nil { + return err + } + sub, ok := w.subscriptions.Read(wsMsg.ID) + if !ok { + return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID) + } + return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload) +} + func (w *webSocketClient) receiveWebSocketConnAck() (bool, error) { _, message, err := w.conn.ReadMessage() if err != nil { @@ -125,20 +124,6 @@ func checkConnectionAckReceived(message []byte) (bool, error) { return wsMessage.Type == webSocketTypeConnAck, nil } -func forwardWebSocketData(respChan chan json.RawMessage, message []byte) error { - var wsMsg webSocketReceiveMessage - err := json.Unmarshal(message, &wsMsg) - if err != nil { - return err - } - switch wsMsg.Type { - case webSocketTypeNext, webSocketTypeError: - respChan <- wsMsg.Payload - default: - } - return nil -} - // formatCloseMessage formats closeCode and text as a WebSocket close message. // An empty message is returned for code CloseNoStatusReceived. func formatCloseMessage(closeCode int, text string) []byte { diff --git a/internal/integration/generated.go b/internal/integration/generated.go index 4a55eb3d..371fd1d3 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -3094,26 +3094,20 @@ subscription count { } ` -// To close the connection, use [graphql.WebSocketClient.CloseWebSocket()] +// To unsubscribe, use [graphql.WebSocketClient.Unsubscribe] func count( ctx_ context.Context, client_ graphql.WebSocketClient, -) (dataChan_ chan countWsResponse, errChan_ chan error, err_ error) { +) (dataChan_ chan countWsResponse, subscriptionID_ string, err_ error) { req_ := &graphql.Request{ OpName: "count", Query: count_Operation, } dataChan_ = make(chan countWsResponse) - respChan_ := make(chan json.RawMessage) + subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, countForwardData) - errChan_, err_ = client_.DialWebSocket(ctx_, req_, respChan_) - if err_ != nil { - return nil, nil, err_ - } - go countForwardData(ctx_, dataChan_, respChan_, errChan_) - - return dataChan_, errChan_, err_ + return dataChan_, subscriptionID_, err_ } type countWsResponse struct { @@ -3122,37 +3116,27 @@ type countWsResponse struct { Errors error `json:"errors"` } -func countForwardData(ctx_ context.Context, dataChan_ chan countWsResponse, respChan chan json.RawMessage, errChan_ chan error) { - defer close(dataChan_) +func countForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { var gqlResp graphql.Response var wsResp countWsResponse - for { - select { - case <-ctx_.Done(): - errChan_ <- errors.New("context was canceled") - return - default: - jsonRaw, more := <-respChan - if !more { - return - } - err := json.Unmarshal(jsonRaw, &gqlResp) - if err != nil { - errChan_ <- err - return - } - if len(gqlResp.Errors) == 0 { - err = json.Unmarshal(jsonRaw, &wsResp) - if err != nil { - errChan_ <- err - return - } - } else { - wsResp.Errors = gqlResp.Errors - } - dataChan_ <- wsResp + err := json.Unmarshal(jsonRawMsg, &gqlResp) + if err != nil { + return err + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRawMsg, &wsResp) + if err != nil { + return err } + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_, ok := interfaceChan.(chan countWsResponse) + if !ok { + return errors.New("failed to cast interface into 'chan countWsResponse'") } + dataChan_ <- wsResp + return nil } // The mutation executed by createUser. diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 0ed4520d..53cfc58d 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -66,14 +66,17 @@ func TestSubscription(t *testing.T) { defer server.Close() wsClient := newRoundtripWebScoketClient(t, server.URL) - start := time.Now() - respChan, errChan, err := count(ctx, wsClient) + errChan, err := wsClient.StartWebSocket(ctx) + require.NoError(t, err) + + dataChan, subscriptionID, err := count(ctx, wsClient) require.NoError(t, err) defer wsClient.CloseWebSocket() counter := 0 + start := time.Now() for loop := true; loop; { select { - case resp, more := <-respChan: + case resp, more := <-dataChan: if !more { loop = false break @@ -81,11 +84,15 @@ func TestSubscription(t *testing.T) { require.NotNil(t, resp.Data) assert.Equal(t, counter, resp.Data.Count) require.Nil(t, resp.Errors) - loop = time.Since(start) < time.Second*2 + if time.Since(start) > time.Second*5 { + err = wsClient.Unsubscribe(subscriptionID) + require.NoError(t, err) + loop = false + } counter++ case err := <-errChan: require.NoError(t, err) - case <-time.After(time.Second * 5): + case <-time.After(time.Second * 10): require.NoError(t, fmt.Errorf("subscription timed out")) } } diff --git a/internal/integration/roundtrip.go b/internal/integration/roundtrip.go index 8b386c84..665c52f9 100644 --- a/internal/integration/roundtrip.go +++ b/internal/integration/roundtrip.go @@ -106,14 +106,22 @@ func (c *roundtripClient) MakeRequest(ctx context.Context, req *graphql.Request, return nil } -func (c *roundtripClient) DialWebSocket(ctx context.Context, req *graphql.Request, respChan chan json.RawMessage) (errChan chan error, err error) { - return c.wsWrapped.DialWebSocket(ctx, req, respChan) +func (c *roundtripClient) StartWebSocket(ctx context.Context) (errChan chan error, err error) { + return c.wsWrapped.StartWebSocket(ctx) } func (c *roundtripClient) CloseWebSocket() { c.wsWrapped.CloseWebSocket() } +func (c *roundtripClient) Subscribe(req *graphql.Request, interfaceChan interface{}, forwardDataFunc graphql.ForwardDataFunction) (string, error) { + return c.wsWrapped.Subscribe(req, interfaceChan, forwardDataFunc) +} + +func (c *roundtripClient) Unsubscribe(subscriptionID string) error { + return c.wsWrapped.Unsubscribe(subscriptionID) +} + func newRoundtripClients(t *testing.T, endpoint string) []graphql.Client { return []graphql.Client{newRoundtripClient(t, endpoint), newRoundtripGetClient(t, endpoint)} } From 495be60abb223bddd112a383ec9c01dfca158025 Mon Sep 17 00:00:00 2001 From: matthieu4294967296moineau Date: Mon, 9 Oct 2023 17:44:58 +0200 Subject: [PATCH 9/9] check ctx --- graphql/client.go | 2 +- graphql/websocket.go | 25 ++++++++++++++++--------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/graphql/client.go b/graphql/client.go index 34882b1f..ee3fd743 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -349,7 +349,7 @@ func (w *webSocketClient) StartWebSocket(ctx context.Context) (errChan chan erro w.conn.Close() return nil, err } - go w.listenWebSocket() + go w.listenWebSocket(ctx) return w.errChan, err } diff --git a/graphql/websocket.go b/graphql/websocket.go index 2bd04ceb..4b87593b 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -1,6 +1,7 @@ package graphql import ( + "context" "encoding/binary" "encoding/json" "fmt" @@ -79,17 +80,23 @@ func (w *webSocketClient) waitForConnAck() error { return nil } -func (w *webSocketClient) listenWebSocket() { +func (w *webSocketClient) listenWebSocket(ctx context.Context) { for { - _, message, err := w.conn.ReadMessage() - if err != nil { - w.errChan <- err - return - } - err = w.forwardWebSocketData(message) - if err != nil { - w.errChan <- err + select { + case <-ctx.Done(): + w.errChan <- fmt.Errorf("context canceled") return + default: + _, message, err := w.conn.ReadMessage() + if err != nil { + w.errChan <- err + return + } + err = w.forwardWebSocketData(message) + if err != nil { + w.errChan <- err + return + } } } }