Skip to content

Commit

Permalink
Call middleware and directives for subscriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Sep 22, 2019
1 parent 17f32d2 commit 24ad9ad
Show file tree
Hide file tree
Showing 9 changed files with 684 additions and 140 deletions.
26 changes: 14 additions & 12 deletions client/websocket.go
Expand Up @@ -8,7 +8,6 @@ import (
"strings"

"github.com/gorilla/websocket"
"github.com/vektah/gqlparser/gqlerror"
)

const (
Expand Down Expand Up @@ -44,6 +43,13 @@ func (p *Client) Websocket(query string, options ...Option) *Subscription {
return p.WebsocketWithPayload(query, nil, options...)
}

// Grab a single response from a websocket based query
func (p *Client) WebsocketOnce(query string, resp interface{}, options ...Option) error {
sock := p.Websocket(query)
defer sock.Close()
return sock.Next(&resp)
}

func (p *Client) WebsocketWithPayload(query string, initPayload map[string]interface{}, options ...Option) *Subscription {
r, err := p.newRequest(query, options...)
if err != nil {
Expand Down Expand Up @@ -119,23 +125,19 @@ func (p *Client) WebsocketWithPayload(query string, initPayload map[string]inter
}
}

respDataRaw := map[string]interface{}{}
var respDataRaw Response
err = json.Unmarshal(op.Payload, &respDataRaw)
if err != nil {
return fmt.Errorf("decode: %s", err.Error())
}

if respDataRaw["errors"] != nil {
var errs []*gqlerror.Error
if err = unpack(respDataRaw["errors"], &errs); err != nil {
return err
}
if len(errs) > 0 {
return fmt.Errorf("errors: %s", errs)
}
}
// we want to unpack even if there is an error, so we can see partial responses
unpackErr := unpack(respDataRaw.Data, response)

return unpack(respDataRaw["data"], response)
if respDataRaw.Errors != nil {
return RawJsonError{respDataRaw.Errors}
}
return unpackErr
},
}
}
3 changes: 3 additions & 0 deletions codegen/field.go
Expand Up @@ -27,6 +27,7 @@ type Field struct {
NoErr bool // If this is bound to a go method, does that method have an error as the second argument
Object *Object // A link back to the parent object
Default interface{} // The default value
Stream bool // does this field return a channel?
Directives []*Directive
}

Expand Down Expand Up @@ -84,6 +85,8 @@ func (b *builder) bindField(obj *Object, f *Field) error {
}
}()

f.Stream = obj.Stream

switch {
case f.Name == "__schema":
f.GoFieldType = GoFieldMethod
Expand Down
130 changes: 56 additions & 74 deletions codegen/field.gotpl
@@ -1,29 +1,59 @@
{{- range $object := .Objects }}{{- range $field := $object.Fields }}

{{- if $object.Stream }}
func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Context, field graphql.CollectedField) func() graphql.Marshaler {
ctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{
Field: field,
Args: nil,
func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Context, field graphql.CollectedField{{ if not $object.Root }}, obj {{$object.Reference | ref}}{{end}}) (ret {{ if $object.Stream }}func(){{ end }}graphql.Marshaler) {
{{- $null := "graphql.Null" }}
{{- if $object.Stream }}
{{- $null = "nil" }}
{{- end }}
ctx = ec.Tracer.StartFieldExecution(ctx, field)
defer func () {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = {{ $null }}
}
ec.Tracer.EndFieldExecution(ctx)
}()
rctx := &graphql.ResolverContext{
Object: {{$object.Name|quote}},
Field: field,
Args: nil,
IsMethod: {{or $field.IsMethod $field.IsResolver}},
}
ctx = graphql.WithResolverContext(ctx, rctx)
{{- if $field.Args }}
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.{{ $field.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return {{ $null }}
}
rctx.Args = args
{{- end }}
ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx)
{{- if $.Directives.LocationDirectives "FIELD" }}
resTmp := ec._fieldMiddleware(ctx, {{if $object.Root}}nil{{else}}obj{{end}}, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
})
{{ else }}
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
})
{{- if $field.Args }}
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.{{ $field.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return nil
}
{{- end }}
// FIXME: subscriptions are missing request middleware stack https://github.com/99designs/gqlgen/issues/259
// and Tracer stack
rctx := ctx
results, err := ec.resolvers.{{ $field.ShortInvocation }}
if err != nil {
ec.Error(ctx, err)
return nil
return {{ $null }}
}
{{- end }}
if resTmp == nil {
{{- if $field.TypeReference.GQL.NonNull }}
if !ec.HasError(rctx) {
ec.Errorf(ctx, "must not be null")
}
{{- end }}
return {{ $null }}
}
{{- if $object.Stream }}
return func() graphql.Marshaler {
res, ok := <-results
res, ok := <-resTmp.(<-chan {{$field.TypeReference.GO | ref}})
if !ok {
return nil
}
Expand All @@ -35,61 +65,13 @@
w.Write([]byte{'}'})
})
}
}
{{ else }}
func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Context, field graphql.CollectedField{{ if not $object.Root }}, obj {{$object.Reference | ref}}{{end}}) (ret graphql.Marshaler) {
ctx = ec.Tracer.StartFieldExecution(ctx, field)
defer func () {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
ec.Tracer.EndFieldExecution(ctx)
}()
rctx := &graphql.ResolverContext{
Object: {{$object.Name|quote}},
Field: field,
Args: nil,
IsMethod: {{or $field.IsMethod $field.IsResolver}},
}
ctx = graphql.WithResolverContext(ctx, rctx)
{{- if $field.Args }}
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.{{ $field.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
rctx.Args = args
{{- end }}
ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx)
{{- if $.Directives.LocationDirectives "FIELD" }}
resTmp := ec._fieldMiddleware(ctx, {{if $object.Root}}nil{{else}}obj{{end}}, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
})
{{ else }}
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
{{- end }}
if resTmp == nil {
{{- if $field.TypeReference.GQL.NonNull }}
if !ec.HasError(rctx) {
ec.Errorf(ctx, "must not be null")
}
{{- end }}
return graphql.Null
}
{{- else }}
res := resTmp.({{$field.TypeReference.GO | ref}})
rctx.Result = res
ctx = ec.Tracer.StartFieldChildExecution(ctx)
return ec.{{ $field.TypeReference.MarshalFunc }}(ctx, field.Selections, res)
}
{{ end }}
{{- end }}
}

{{- end }}{{- end}}

Expand All @@ -107,10 +89,10 @@
if tmp == nil {
return nil, nil
}
if data, ok := tmp.({{ .TypeReference.GO | ref }}) ; ok {
if data, ok := tmp.({{if .Stream}}<-chan {{end}}{{ .TypeReference.GO | ref }}) ; ok {
return data, nil
}
return nil, fmt.Errorf(`unexpected type %T from directive, should be {{ .TypeReference.GO }}`, tmp)
return nil, fmt.Errorf(`unexpected type %T from directive, should be {{if .Stream}}<-chan {{end}}{{ .TypeReference.GO }}`, tmp)
{{- else -}}
ctx = rctx // use context from middleware stack in children
{{ template "fieldDefinition" . }}
Expand All @@ -122,9 +104,9 @@
return ec.resolvers.{{ .ShortInvocation }}
{{- else if .IsMap -}}
switch v := {{.GoReceiverName}}[{{.Name|quote}}].(type) {
case {{.TypeReference.GO | ref}}:
case {{if .Stream}}<-chan {{end}}{{.TypeReference.GO | ref}}:
return v, nil
case {{.TypeReference.Elem.GO | ref}}:
case {{if .Stream}}<-chan {{end}}{{.TypeReference.Elem.GO | ref}}:
return &v, nil
case nil:
return ({{.TypeReference.GO | ref}})(nil), nil
Expand Down
7 changes: 7 additions & 0 deletions codegen/testserver/directive.graphql
Expand Up @@ -21,6 +21,13 @@ extend type Query {
directiveUnimplemented: String @unimplemented
}

extend type Subscription {
directiveArg(arg: String! @length(min:1, max: 255, message: "invalid length")): String
directiveNullableArg(arg: Int @range(min:0), arg2: Int @range, arg3: String @toNull): String
directiveDouble: String @directive1 @directive2
directiveUnimplemented: String @unimplemented
}

input InputDirectives {
text: String! @length(min: 0, max: 7, message: "not valid")
nullableText: String @toNull
Expand Down

0 comments on commit 24ad9ad

Please sign in to comment.