Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call middleware and directives for subscriptions #871

Merged
merged 1 commit into from Sep 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion .golangci.yml
Expand Up @@ -31,7 +31,6 @@ linters:
- stylecheck
- typecheck
- unconvert
- unparam
- unused
- varcheck

Expand Down
26 changes: 14 additions & 12 deletions client/websocket.go
Expand Up @@ -8,7 +8,6 @@ import (
"strings"

"github.com/gorilla/websocket"
"github.com/vektah/gqlparser/gqlerror"
)

const (
Expand Down Expand Up @@ -44,6 +43,13 @@ func (p *Client) Websocket(query string, options ...Option) *Subscription {
return p.WebsocketWithPayload(query, nil, options...)
}

// Grab a single response from a websocket based query
func (p *Client) WebsocketOnce(query string, resp interface{}, options ...Option) error {
sock := p.Websocket(query)
defer sock.Close()
return sock.Next(&resp)
}

func (p *Client) WebsocketWithPayload(query string, initPayload map[string]interface{}, options ...Option) *Subscription {
r, err := p.newRequest(query, options...)
if err != nil {
Expand Down Expand Up @@ -119,23 +125,19 @@ func (p *Client) WebsocketWithPayload(query string, initPayload map[string]inter
}
}

respDataRaw := map[string]interface{}{}
var respDataRaw Response
err = json.Unmarshal(op.Payload, &respDataRaw)
if err != nil {
return fmt.Errorf("decode: %s", err.Error())
}

if respDataRaw["errors"] != nil {
var errs []*gqlerror.Error
if err = unpack(respDataRaw["errors"], &errs); err != nil {
return err
}
if len(errs) > 0 {
return fmt.Errorf("errors: %s", errs)
}
}
// we want to unpack even if there is an error, so we can see partial responses
unpackErr := unpack(respDataRaw.Data, response)

return unpack(respDataRaw["data"], response)
if respDataRaw.Errors != nil {
return RawJsonError{respDataRaw.Errors}
}
return unpackErr
},
}
}
3 changes: 3 additions & 0 deletions codegen/field.go
Expand Up @@ -27,6 +27,7 @@ type Field struct {
NoErr bool // If this is bound to a go method, does that method have an error as the second argument
Object *Object // A link back to the parent object
Default interface{} // The default value
Stream bool // does this field return a channel?
Directives []*Directive
}

Expand Down Expand Up @@ -84,6 +85,8 @@ func (b *builder) bindField(obj *Object, f *Field) error {
}
}()

f.Stream = obj.Stream

switch {
case f.Name == "__schema":
f.GoFieldType = GoFieldMethod
Expand Down
130 changes: 56 additions & 74 deletions codegen/field.gotpl
@@ -1,29 +1,59 @@
{{- range $object := .Objects }}{{- range $field := $object.Fields }}

{{- if $object.Stream }}
func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Context, field graphql.CollectedField) func() graphql.Marshaler {
ctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{
Field: field,
Args: nil,
func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Context, field graphql.CollectedField{{ if not $object.Root }}, obj {{$object.Reference | ref}}{{end}}) (ret {{ if $object.Stream }}func(){{ end }}graphql.Marshaler) {
{{- $null := "graphql.Null" }}
{{- if $object.Stream }}
{{- $null = "nil" }}
{{- end }}
ctx = ec.Tracer.StartFieldExecution(ctx, field)
defer func () {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = {{ $null }}
}
ec.Tracer.EndFieldExecution(ctx)
}()
rctx := &graphql.ResolverContext{
Object: {{$object.Name|quote}},
Field: field,
Args: nil,
IsMethod: {{or $field.IsMethod $field.IsResolver}},
}
ctx = graphql.WithResolverContext(ctx, rctx)
{{- if $field.Args }}
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.{{ $field.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return {{ $null }}
}
rctx.Args = args
{{- end }}
ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx)
{{- if $.Directives.LocationDirectives "FIELD" }}
resTmp := ec._fieldMiddleware(ctx, {{if $object.Root}}nil{{else}}obj{{end}}, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
})
{{ else }}
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
})
{{- if $field.Args }}
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.{{ $field.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return nil
}
{{- end }}
// FIXME: subscriptions are missing request middleware stack https://github.com/99designs/gqlgen/issues/259
// and Tracer stack
rctx := ctx
results, err := ec.resolvers.{{ $field.ShortInvocation }}
if err != nil {
ec.Error(ctx, err)
return nil
return {{ $null }}
}
{{- end }}
if resTmp == nil {
{{- if $field.TypeReference.GQL.NonNull }}
if !ec.HasError(rctx) {
ec.Errorf(ctx, "must not be null")
}
{{- end }}
return {{ $null }}
}
{{- if $object.Stream }}
return func() graphql.Marshaler {
res, ok := <-results
res, ok := <-resTmp.(<-chan {{$field.TypeReference.GO | ref}})
if !ok {
return nil
}
Expand All @@ -35,61 +65,13 @@
w.Write([]byte{'}'})
})
}
}
{{ else }}
func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Context, field graphql.CollectedField{{ if not $object.Root }}, obj {{$object.Reference | ref}}{{end}}) (ret graphql.Marshaler) {
ctx = ec.Tracer.StartFieldExecution(ctx, field)
defer func () {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
ec.Tracer.EndFieldExecution(ctx)
}()
rctx := &graphql.ResolverContext{
Object: {{$object.Name|quote}},
Field: field,
Args: nil,
IsMethod: {{or $field.IsMethod $field.IsResolver}},
}
ctx = graphql.WithResolverContext(ctx, rctx)
{{- if $field.Args }}
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.{{ $field.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
rctx.Args = args
{{- end }}
ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx)
{{- if $.Directives.LocationDirectives "FIELD" }}
resTmp := ec._fieldMiddleware(ctx, {{if $object.Root}}nil{{else}}obj{{end}}, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
})
{{ else }}
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
{{- end }}
if resTmp == nil {
{{- if $field.TypeReference.GQL.NonNull }}
if !ec.HasError(rctx) {
ec.Errorf(ctx, "must not be null")
}
{{- end }}
return graphql.Null
}
{{- else }}
res := resTmp.({{$field.TypeReference.GO | ref}})
rctx.Result = res
ctx = ec.Tracer.StartFieldChildExecution(ctx)
return ec.{{ $field.TypeReference.MarshalFunc }}(ctx, field.Selections, res)
}
{{ end }}
{{- end }}
}

{{- end }}{{- end}}

Expand All @@ -107,10 +89,10 @@
if tmp == nil {
return nil, nil
}
if data, ok := tmp.({{ .TypeReference.GO | ref }}) ; ok {
if data, ok := tmp.({{if .Stream}}<-chan {{end}}{{ .TypeReference.GO | ref }}) ; ok {
return data, nil
}
return nil, fmt.Errorf(`unexpected type %T from directive, should be {{ .TypeReference.GO }}`, tmp)
return nil, fmt.Errorf(`unexpected type %T from directive, should be {{if .Stream}}<-chan {{end}}{{ .TypeReference.GO }}`, tmp)
{{- else -}}
ctx = rctx // use context from middleware stack in children
{{ template "fieldDefinition" . }}
Expand All @@ -122,9 +104,9 @@
return ec.resolvers.{{ .ShortInvocation }}
{{- else if .IsMap -}}
switch v := {{.GoReceiverName}}[{{.Name|quote}}].(type) {
case {{.TypeReference.GO | ref}}:
case {{if .Stream}}<-chan {{end}}{{.TypeReference.GO | ref}}:
return v, nil
case {{.TypeReference.Elem.GO | ref}}:
case {{if .Stream}}<-chan {{end}}{{.TypeReference.Elem.GO | ref}}:
return &v, nil
case nil:
return ({{.TypeReference.GO | ref}})(nil), nil
Expand Down
7 changes: 7 additions & 0 deletions codegen/testserver/directive.graphql
Expand Up @@ -21,6 +21,13 @@ extend type Query {
directiveUnimplemented: String @unimplemented
}

extend type Subscription {
directiveArg(arg: String! @length(min:1, max: 255, message: "invalid length")): String
directiveNullableArg(arg: Int @range(min:0), arg2: Int @range, arg3: String @toNull): String
directiveDouble: String @directive1 @directive2
directiveUnimplemented: String @unimplemented
}

input InputDirectives {
text: String! @length(min: 0, max: 7, message: "not valid")
nullableText: String @toNull
Expand Down