From 238a7e2fe255b6e9bd2106f3d0663ef7909fc62e Mon Sep 17 00:00:00 2001 From: Evan Shaw Date: Fri, 24 Aug 2018 15:51:58 +1200 Subject: [PATCH] Add complexity support to codegen, handler --- codegen/config.go | 5 ++-- codegen/object.go | 48 ++++++++++++++++++++++++------- codegen/object_build.go | 15 ++++++---- codegen/templates/data.go | 2 +- codegen/templates/generated.gotpl | 48 ++++++++++++++++++++++++++++++- graphql/exec.go | 1 + handler/graphql.go | 28 ++++++++++++++---- handler/stub.go | 4 +++ 8 files changed, 126 insertions(+), 25 deletions(-) diff --git a/codegen/config.go b/codegen/config.go index 8e2a5e74b3..36debec99c 100644 --- a/codegen/config.go +++ b/codegen/config.go @@ -83,8 +83,9 @@ type TypeMapEntry struct { } type TypeMapField struct { - Resolver bool `yaml:"resolver"` - FieldName string `yaml:"fieldName"` + Resolver bool `yaml:"resolver"` + Complexity bool `yaml:"complexity"` + FieldName string `yaml:"fieldName"` } func (c *PackageConfig) normalize() error { diff --git a/codegen/object.go b/codegen/object.go index 93d1952102..f92d020599 100644 --- a/codegen/object.go +++ b/codegen/object.go @@ -33,15 +33,16 @@ type Object struct { type Field struct { *Type - GQLName string // The name of the field in graphql - GoFieldType GoFieldType // The field type in go, if any - GoReceiverName string // The name of method & var receiver in go, if any - GoFieldName string // The name of the method or var in go, if any - Args []FieldArgument // A list of arguments to be passed to this field - ForceResolver bool // Should be emit Resolver method - NoErr bool // If this is bound to a go method, does that method have an error as the second argument - Object *Object // A link back to the parent object - Default interface{} // The default value + GQLName string // The name of the field in graphql + GoFieldType GoFieldType // The field type in go, if any + GoReceiverName string // The name of method & var receiver in go, if any + GoFieldName string // The name of the method or var in go, if any + Args []FieldArgument // A list of arguments to be passed to this field + ForceResolver bool // Should be emit Resolver method + CustomComplexity bool // Uses a custom complexity calculation + NoErr bool // If this is bound to a go method, does that method have an error as the second argument + Object *Object // A link back to the parent object + Default interface{} // The default value } type FieldArgument struct { @@ -81,6 +82,15 @@ func (o *Object) IsConcurrent() bool { return false } +func (o *Object) HasComplexity() bool { + for _, f := range o.Fields { + if f.CustomComplexity { + return true + } + } + return false +} + func (f *Field) IsResolver() bool { return f.GoFieldName == "" } @@ -165,6 +175,24 @@ func (f *Field) ResolverDeclaration() string { return res } +func (f *Field) ComplexitySignature() string { + res := fmt.Sprintf("func(childComplexity int") + for _, arg := range f.Args { + res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature()) + } + res += ") int" + return res +} + +func (f *Field) ComplexityArgs() string { + var args []string + for _, arg := range f.Args { + args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")") + } + + return strings.Join(args, ", ") +} + func (f *Field) CallArgs() string { var args []string @@ -227,7 +255,7 @@ func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Typ ctx := graphql.WithResolverContext(ctx, rctx) {{- end}} {{.arr}} = append({{.arr}}, func() graphql.Marshaler { - {{ .next }} + {{ .next }} }()) } return {{.arr}}`, map[string]interface{}{ diff --git a/codegen/object_build.go b/codegen/object_build.go index 95602a9138..9b3931f1c3 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -135,11 +135,13 @@ func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition, imports *I } var forceResolver bool + var customComplexity bool var goName string if entryExists { if typeField, ok := typeEntry.Fields[field.Name]; ok { goName = typeField.FieldName forceResolver = typeField.Resolver + customComplexity = typeField.Complexity } } @@ -168,12 +170,13 @@ func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition, imports *I } obj.Fields = append(obj.Fields, Field{ - GQLName: field.Name, - Type: types.getType(field.Type), - Args: args, - Object: obj, - GoFieldName: goName, - ForceResolver: forceResolver, + GQLName: field.Name, + Type: types.getType(field.Type), + Args: args, + Object: obj, + GoFieldName: goName, + ForceResolver: forceResolver, + CustomComplexity: customComplexity, }) } diff --git a/codegen/templates/data.go b/codegen/templates/data.go index db2816c023..7718c5b6ab 100644 --- a/codegen/templates/data.go +++ b/codegen/templates/data.go @@ -3,7 +3,7 @@ package templates var data = map[string]string{ "args.gotpl": "\t{{- if . }}args := map[string]interface{}{} {{end}}\n\t{{- range $i, $arg := . }}\n\t\tvar arg{{$i}} {{$arg.Signature }}\n\t\tif tmp, ok := rawArgs[{{$arg.GQLName|quote}}]; ok {\n\t\t\tvar err error\n\t\t\t{{$arg.Unmarshal (print \"arg\" $i) \"tmp\" }}\n\t\t\tif err != nil {\n\t\t\t\tec.Error(ctx, err)\n\t\t\t\t{{- if $arg.Stream }}\n\t\t\t\t\treturn nil\n\t\t\t\t{{- else }}\n\t\t\t\t\treturn graphql.Null\n\t\t\t\t{{- end }}\n\t\t\t}\n\t\t}\n\t\targs[{{$arg.GQLName|quote}}] = arg{{$i}}\n\t{{- end -}}\n", "field.gotpl": "{{ $field := . }}\n{{ $object := $field.Object }}\n\n{{- if $object.Stream }}\n\tfunc (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(ctx context.Context, field graphql.CollectedField) func() graphql.Marshaler {\n\t\t{{- if $field.Args }}\n\t\t\trawArgs := field.ArgumentMap(ec.Variables)\n\t\t\t{{ template \"args.gotpl\" $field.Args }}\n\t\t{{- end }}\n\t\tctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{\n\t\t\tField: field,\n\t\t})\n\t\tresults, err := ec.resolvers.{{ $field.ShortInvocation }}\n\t\tif err != nil {\n\t\t\tec.Error(ctx, err)\n\t\t\treturn nil\n\t\t}\n\t\treturn func() graphql.Marshaler {\n\t\t\tres, ok := <-results\n\t\t\tif !ok {\n\t\t\t\treturn nil\n\t\t\t}\n\t\t\tvar out graphql.OrderedMap\n\t\t\tout.Add(field.Alias, func() graphql.Marshaler { {{ $field.WriteJson }} }())\n\t\t\treturn &out\n\t\t}\n\t}\n{{ else }}\n\t// nolint: vetshadow\n\tfunc (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(ctx context.Context, field graphql.CollectedField, {{if not $object.Root}}obj *{{$object.FullName}}{{end}}) graphql.Marshaler {\n\t\t{{- if $field.Args }}\n\t\t\trawArgs := field.ArgumentMap(ec.Variables)\n\t\t\t{{ template \"args.gotpl\" $field.Args }}\n\t\t{{- end }}\n\t\trctx := &graphql.ResolverContext{\n\t\t\tObject: {{$object.GQLType|quote}},\n\t\t\tArgs: {{if $field.Args }}args{{else}}nil{{end}},\n\t\t\tField: field,\n\t\t}\n\t\tctx = graphql.WithResolverContext(ctx, rctx)\n\t\tresTmp := ec.FieldMiddleware(ctx, func(ctx context.Context) (interface{}, error) {\n\t\t\t{{- if $field.IsResolver }}\n\t\t\t\treturn ec.resolvers.{{ $field.ShortInvocation }}\n\t\t\t{{- else if $field.IsMethod }}\n\t\t\t\t{{- if $field.NoErr }}\n\t\t\t\t\treturn {{$field.GoReceiverName}}.{{$field.GoFieldName}}({{ $field.CallArgs }}), nil\n\t\t\t\t{{- else }}\n\t\t\t\t\treturn {{$field.GoReceiverName}}.{{$field.GoFieldName}}({{ $field.CallArgs }})\n\t\t\t\t{{- end }}\n\t\t\t{{- else if $field.IsVariable }}\n\t\t\t\treturn {{$field.GoReceiverName}}.{{$field.GoFieldName}}, nil\n\t\t\t{{- end }}\n\t\t})\n\t\tif resTmp == nil {\n\t\t\t{{- if $field.ASTType.NonNull }}\n\t\t\t\tif !ec.HasError(rctx) {\n\t\t\t\t\tec.Errorf(ctx, \"must not be null\")\n\t\t\t\t}\n\t\t\t{{- end }}\n\t\t\treturn graphql.Null\n\t\t}\n\t\tres := resTmp.({{$field.Signature}})\n\t\trctx.Result = res\n\t\t{{ $field.WriteJson }}\n\t}\n{{ end }}\n", - "generated.gotpl": "// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\npackage {{ .PackageName }}\n\nimport (\n{{- range $import := .Imports }}\n\t{{- $import.Write }}\n{{ end }}\n)\n\n// NewExecutableSchema creates an ExecutableSchema from the ResolverRoot interface.\nfunc NewExecutableSchema(cfg Config) graphql.ExecutableSchema {\n\treturn &executableSchema{\n\t\tresolvers: cfg.Resolvers,\n\t\tdirectives: cfg.Directives,\n\t}\n}\n\ntype Config struct {\n\tResolvers ResolverRoot\n\tDirectives DirectiveRoot\n}\n\ntype ResolverRoot interface {\n{{- range $object := .Objects -}}\n\t{{ if $object.HasResolvers -}}\n\t\t{{$object.GQLType}}() {{$object.GQLType}}Resolver\n\t{{ end }}\n{{- end }}\n}\n\ntype DirectiveRoot struct {\n{{ range $directive := .Directives }}\n\t{{ $directive.Declaration }}\n{{ end }}\n}\n\n{{- range $object := .Objects -}}\n\t{{ if $object.HasResolvers }}\n\t\ttype {{$object.GQLType}}Resolver interface {\n\t\t{{ range $field := $object.Fields -}}\n\t\t\t{{ $field.ShortResolverDeclaration }}\n\t\t{{ end }}\n\t\t}\n\t{{- end }}\n{{- end }}\n\ntype executableSchema struct {\n\tresolvers ResolverRoot\n\tdirectives DirectiveRoot\n}\n\nfunc (e *executableSchema) Schema() *ast.Schema {\n\treturn parsedSchema\n}\n\nfunc (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {\n\t{{- if .QueryRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e}\n\n\t\tbuf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte {\n\t\t\tdata := ec._{{.QueryRoot.GQLType}}(ctx, op.SelectionSet)\n\t\t\tvar buf bytes.Buffer\n\t\t\tdata.MarshalGQL(&buf)\n\t\t\treturn buf.Bytes()\n\t\t})\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf,\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn graphql.ErrorResponse(ctx, \"queries are not supported\")\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Mutation(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {\n\t{{- if .MutationRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e}\n\n\t\tbuf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte {\n\t\t\tdata := ec._{{.MutationRoot.GQLType}}(ctx, op.SelectionSet)\n\t\t\tvar buf bytes.Buffer\n\t\t\tdata.MarshalGQL(&buf)\n\t\t\treturn buf.Bytes()\n\t\t})\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf,\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn graphql.ErrorResponse(ctx, \"mutations are not supported\")\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Subscription(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response {\n\t{{- if .SubscriptionRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e}\n\n\t\tnext := ec._{{.SubscriptionRoot.GQLType}}(ctx, op.SelectionSet)\n\t\tif ec.Errors != nil {\n\t\t\treturn graphql.OneShot(&graphql.Response{Data: []byte(\"null\"), Errors: ec.Errors})\n\t\t}\n\n\t\tvar buf bytes.Buffer\n\t\treturn func() *graphql.Response {\n\t\t\tbuf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte {\n\t\t\t\tbuf.Reset()\n\t\t\t\tdata := next()\n\n\t\t\t\tif data == nil {\n\t\t\t\t\treturn nil\n\t\t\t\t}\n\t\t\t\tdata.MarshalGQL(&buf)\n\t\t\t\treturn buf.Bytes()\n\t\t\t})\n\n\t\t\treturn &graphql.Response{\n\t\t\t\tData: buf,\n\t\t\t\tErrors: ec.Errors,\n\t\t\t}\n\t\t}\n\t{{- else }}\n\t\treturn graphql.OneShot(graphql.ErrorResponse(ctx, \"subscriptions are not supported\"))\n\t{{- end }}\n}\n\ntype executionContext struct {\n\t*graphql.RequestContext\n\t*executableSchema\n}\n\n{{- range $object := .Objects }}\n\t{{ template \"object.gotpl\" $object }}\n\n\t{{- range $field := $object.Fields }}\n\t\t{{ template \"field.gotpl\" $field }}\n\t{{ end }}\n{{- end}}\n\n{{- range $interface := .Interfaces }}\n\t{{ template \"interface.gotpl\" $interface }}\n{{- end }}\n\n{{- range $input := .Inputs }}\n\t{{ template \"input.gotpl\" $input }}\n{{- end }}\n\nfunc (ec *executionContext) FieldMiddleware(ctx context.Context, next graphql.Resolver) (ret interface{}) {\n\tdefer func() {\n\t\tif r := recover(); r != nil {\n\t\t\tec.Error(ctx, ec.Recover(ctx, r))\n\t\t\tret = nil\n\t\t}\n\t}()\n\t{{- if .Directives }}\n\trctx := graphql.GetResolverContext(ctx)\n\tfor _, d := range rctx.Field.Definition.Directives {\n\t\tswitch d.Name {\n\t\t{{- range $directive := .Directives }}\n\t\tcase \"{{$directive.Name}}\":\n\t\t\tif ec.directives.{{$directive.Name|ucFirst}} != nil {\n\t\t\t\t{{- if $directive.Args }}\n\t\t\t\t\trawArgs := d.ArgumentMap(ec.Variables)\n\t\t\t\t\t{{ template \"args.gotpl\" $directive.Args }}\n\t\t\t\t{{- end }}\n\t\t\t\tn := next\n\t\t\t\tnext = func(ctx context.Context) (interface{}, error) {\n\t\t\t\t\treturn ec.directives.{{$directive.Name|ucFirst}}({{$directive.CallArgs}})\n\t\t\t\t}\n\t\t\t}\n\t\t{{- end }}\n\t\t}\n\t}\n\t{{- end }}\n\tres, err := ec.ResolverMiddleware(ctx, next)\n\tif err != nil {\n\t\tec.Error(ctx, err)\n\t\treturn nil\n\t}\n\treturn res\n}\n\nfunc (ec *executionContext) introspectSchema() *introspection.Schema {\n\treturn introspection.WrapSchema(parsedSchema)\n}\n\nfunc (ec *executionContext) introspectType(name string) *introspection.Type {\n\treturn introspection.WrapTypeFromDef(parsedSchema, parsedSchema.Types[name])\n}\n\nvar parsedSchema = gqlparser.MustLoadSchema(\n\t&ast.Source{Name: {{.SchemaFilename|quote}}, Input: {{.SchemaRaw|rawQuote}}},\n)\n", + "generated.gotpl": "// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\npackage {{ .PackageName }}\n\nimport (\n{{- range $import := .Imports }}\n\t{{- $import.Write }}\n{{ end }}\n)\n\n// NewExecutableSchema creates an ExecutableSchema from the ResolverRoot interface.\nfunc NewExecutableSchema(cfg Config) graphql.ExecutableSchema {\n\treturn &executableSchema{\n\t\tresolvers: cfg.Resolvers,\n\t\tdirectives: cfg.Directives,\n\t\tcomplexity: cfg.Complexity,\n\t}\n}\n\ntype Config struct {\n\tResolvers ResolverRoot\n\tDirectives DirectiveRoot\n\tComplexity ComplexityRoot\n}\n\ntype ResolverRoot interface {\n{{- range $object := .Objects -}}\n\t{{ if $object.HasResolvers -}}\n\t\t{{$object.GQLType}}() {{$object.GQLType}}Resolver\n\t{{ end }}\n{{- end }}\n}\n\ntype DirectiveRoot struct {\n{{ range $directive := .Directives }}\n\t{{ $directive.Declaration }}\n{{ end }}\n}\n\ntype ComplexityRoot struct {\n{{ range $object := .Objects }}\n\t{{ if $object.HasComplexity }}\n\t\t{{ $object.GoType }} struct {\n\t\t{{ range $field := $object.Fields }}\n\t\t\t{{ if $field.CustomComplexity }}\n\t\t\t\t{{ $field.GoFieldName }} {{ $field.ComplexitySignature }}\n\t\t\t{{ end }}\n\t\t{{ end }}\n\t\t}\n\t{{ end }}\n{{ end }}\n}\n\n{{ range $object := .Objects -}}\n\t{{ if $object.HasResolvers }}\n\t\ttype {{$object.GQLType}}Resolver interface {\n\t\t{{ range $field := $object.Fields -}}\n\t\t\t{{ $field.ShortResolverDeclaration }}\n\t\t{{ end }}\n\t\t}\n\t{{- end }}\n{{- end }}\n\ntype executableSchema struct {\n\tresolvers ResolverRoot\n\tdirectives DirectiveRoot\n\tcomplexity ComplexityRoot\n}\n\nfunc (e *executableSchema) Schema() *ast.Schema {\n\treturn parsedSchema\n}\n\nfunc (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {\n\tswitch typeName + \".\" + field {\n\t{{ range $object := .Objects }}\n\t\t{{ range $field := $object.Fields }}\n\t\t\t{{ if $field.CustomComplexity }}\n\t\t\t\tcase \"{{$object.GQLType}}.{{$field.GQLName}}\":\n\t\t\t\t if e.complexity.{{$object.GoType}}.{{$field.GoFieldName}} == nil {\n\t\t\t\t\t\tbreak\n\t\t\t\t\t}\n\t\t\t\t\t{{ if . }}args := map[string]interface{}{} {{end}}\n\t\t\t\t\t{{ range $i, $arg := $field.Args }}\n\t\t\t\t\t\tvar arg{{$i}} {{$arg.Signature }}\n\t\t\t\t\t\tif tmp, ok := rawArgs[{{$arg.GQLName|quote}}]; ok {\n\t\t\t\t\t\t\tvar err error\n\t\t\t\t\t\t\t{{$arg.Unmarshal (print \"arg\" $i) \"tmp\" }}\n\t\t\t\t\t\t\tif err != nil {\n\t\t\t\t\t\t\t\treturn 0, false\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\targs[{{$arg.GQLName|quote}}] = arg{{$i}}\n\t\t\t\t\t{{ end }}\n\t\t\t\t\treturn e.complexity.{{$object.GoType}}.{{$field.GoFieldName}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{end}}), true\n\t\t\t{{ end }}\n\t\t{{ end }}\n\t{{ end }}\n\t}\n\treturn 0, false\n}\n\nfunc (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {\n\t{{- if .QueryRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e}\n\n\t\tbuf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte {\n\t\t\tdata := ec._{{.QueryRoot.GQLType}}(ctx, op.SelectionSet)\n\t\t\tvar buf bytes.Buffer\n\t\t\tdata.MarshalGQL(&buf)\n\t\t\treturn buf.Bytes()\n\t\t})\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf,\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn graphql.ErrorResponse(ctx, \"queries are not supported\")\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Mutation(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {\n\t{{- if .MutationRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e}\n\n\t\tbuf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte {\n\t\t\tdata := ec._{{.MutationRoot.GQLType}}(ctx, op.SelectionSet)\n\t\t\tvar buf bytes.Buffer\n\t\t\tdata.MarshalGQL(&buf)\n\t\t\treturn buf.Bytes()\n\t\t})\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf,\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn graphql.ErrorResponse(ctx, \"mutations are not supported\")\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Subscription(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response {\n\t{{- if .SubscriptionRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e}\n\n\t\tnext := ec._{{.SubscriptionRoot.GQLType}}(ctx, op.SelectionSet)\n\t\tif ec.Errors != nil {\n\t\t\treturn graphql.OneShot(&graphql.Response{Data: []byte(\"null\"), Errors: ec.Errors})\n\t\t}\n\n\t\tvar buf bytes.Buffer\n\t\treturn func() *graphql.Response {\n\t\t\tbuf := ec.RequestMiddleware(ctx, func(ctx context.Context) []byte {\n\t\t\t\tbuf.Reset()\n\t\t\t\tdata := next()\n\n\t\t\t\tif data == nil {\n\t\t\t\t\treturn nil\n\t\t\t\t}\n\t\t\t\tdata.MarshalGQL(&buf)\n\t\t\t\treturn buf.Bytes()\n\t\t\t})\n\n\t\t\treturn &graphql.Response{\n\t\t\t\tData: buf,\n\t\t\t\tErrors: ec.Errors,\n\t\t\t}\n\t\t}\n\t{{- else }}\n\t\treturn graphql.OneShot(graphql.ErrorResponse(ctx, \"subscriptions are not supported\"))\n\t{{- end }}\n}\n\ntype executionContext struct {\n\t*graphql.RequestContext\n\t*executableSchema\n}\n\n{{- range $object := .Objects }}\n\t{{ template \"object.gotpl\" $object }}\n\n\t{{- range $field := $object.Fields }}\n\t\t{{ template \"field.gotpl\" $field }}\n\t{{ end }}\n{{- end}}\n\n{{- range $interface := .Interfaces }}\n\t{{ template \"interface.gotpl\" $interface }}\n{{- end }}\n\n{{- range $input := .Inputs }}\n\t{{ template \"input.gotpl\" $input }}\n{{- end }}\n\nfunc (ec *executionContext) FieldMiddleware(ctx context.Context, next graphql.Resolver) (ret interface{}) {\n\tdefer func() {\n\t\tif r := recover(); r != nil {\n\t\t\tec.Error(ctx, ec.Recover(ctx, r))\n\t\t\tret = nil\n\t\t}\n\t}()\n\t{{- if .Directives }}\n\trctx := graphql.GetResolverContext(ctx)\n\tfor _, d := range rctx.Field.Definition.Directives {\n\t\tswitch d.Name {\n\t\t{{- range $directive := .Directives }}\n\t\tcase \"{{$directive.Name}}\":\n\t\t\tif ec.directives.{{$directive.Name|ucFirst}} != nil {\n\t\t\t\t{{- if $directive.Args }}\n\t\t\t\t\trawArgs := d.ArgumentMap(ec.Variables)\n\t\t\t\t\t{{ template \"args.gotpl\" $directive.Args }}\n\t\t\t\t{{- end }}\n\t\t\t\tn := next\n\t\t\t\tnext = func(ctx context.Context) (interface{}, error) {\n\t\t\t\t\treturn ec.directives.{{$directive.Name|ucFirst}}({{$directive.CallArgs}})\n\t\t\t\t}\n\t\t\t}\n\t\t{{- end }}\n\t\t}\n\t}\n\t{{- end }}\n\tres, err := ec.ResolverMiddleware(ctx, next)\n\tif err != nil {\n\t\tec.Error(ctx, err)\n\t\treturn nil\n\t}\n\treturn res\n}\n\nfunc (ec *executionContext) introspectSchema() *introspection.Schema {\n\treturn introspection.WrapSchema(parsedSchema)\n}\n\nfunc (ec *executionContext) introspectType(name string) *introspection.Type {\n\treturn introspection.WrapTypeFromDef(parsedSchema, parsedSchema.Types[name])\n}\n\nvar parsedSchema = gqlparser.MustLoadSchema(\n\t&ast.Source{Name: {{.SchemaFilename|quote}}, Input: {{.SchemaRaw|rawQuote}}},\n)\n", "input.gotpl": "\t{{- if .IsMarshaled }}\n\tfunc Unmarshal{{ .GQLType }}(v interface{}) ({{.FullName}}, error) {\n\t\tvar it {{.FullName}}\n\t\tvar asMap = v.(map[string]interface{})\n\t\t{{ range $field := .Fields}}\n\t\t\t{{- if $field.Default}}\n\t\t\t\tif _, present := asMap[{{$field.GQLName|quote}}] ; !present {\n\t\t\t\t\tasMap[{{$field.GQLName|quote}}] = {{ $field.Default | dump }}\n\t\t\t\t}\n\t\t\t{{- end}}\n\t\t{{- end }}\n\n\t\tfor k, v := range asMap {\n\t\t\tswitch k {\n\t\t\t{{- range $field := .Fields }}\n\t\t\tcase {{$field.GQLName|quote}}:\n\t\t\t\tvar err error\n\t\t\t\t{{ $field.Unmarshal (print \"it.\" $field.GoFieldName) \"v\" }}\n\t\t\t\tif err != nil {\n\t\t\t\t\treturn it, err\n\t\t\t\t}\n\t\t\t{{- end }}\n\t\t\t}\n\t\t}\n\n\t\treturn it, nil\n\t}\n\t{{- end }}\n", "interface.gotpl": "{{- $interface := . }}\n\nfunc (ec *executionContext) _{{$interface.GQLType}}(ctx context.Context, sel ast.SelectionSet, obj *{{$interface.FullName}}) graphql.Marshaler {\n\tswitch obj := (*obj).(type) {\n\tcase nil:\n\t\treturn graphql.Null\n\t{{- range $implementor := $interface.Implementors }}\n\t\t{{- if $implementor.ValueReceiver }}\n\t\t\tcase {{$implementor.FullName}}:\n\t\t\t\treturn ec._{{$implementor.GQLType}}(ctx, sel, &obj)\n\t\t{{- end}}\n\t\tcase *{{$implementor.FullName}}:\n\t\t\treturn ec._{{$implementor.GQLType}}(ctx, sel, obj)\n\t{{- end }}\n\tdefault:\n\t\tpanic(fmt.Errorf(\"unexpected type %T\", obj))\n\t}\n}\n", "models.gotpl": "// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\npackage {{ .PackageName }}\n\nimport (\n{{- range $import := .Imports }}\n\t{{- $import.Write }}\n{{ end }}\n)\n\n{{ range $model := .Models }}\n\t{{- if .IsInterface }}\n\t\ttype {{.GoType}} interface {}\n\t{{- else }}\n\t\ttype {{.GoType}} struct {\n\t\t\t{{- range $field := .Fields }}\n\t\t\t\t{{- if $field.GoFieldName }}\n\t\t\t\t\t{{ $field.GoFieldName }} {{$field.Signature}} `json:\"{{$field.GQLName}}\"`\n\t\t\t\t{{- else }}\n\t\t\t\t\t{{ $field.GoFKName }} {{$field.GoFKType}}\n\t\t\t\t{{- end }}\n\t\t\t{{- end }}\n\t\t}\n\t{{- end }}\n{{- end}}\n\n{{ range $enum := .Enums }}\n\ttype {{.GoType}} string\n\tconst (\n\t{{ range $value := .Values -}}\n\t\t{{with .Description}} {{.|prefixLines \"// \"}} {{end}}\n\t\t{{$enum.GoType}}{{ .Name|toCamel }} {{$enum.GoType}} = {{.Name|quote}}\n\t{{- end }}\n\t)\n\n\tfunc (e {{.GoType}}) IsValid() bool {\n\t\tswitch e {\n\t\tcase {{ range $index, $element := .Values}}{{if $index}},{{end}}{{ $enum.GoType }}{{ $element.Name|toCamel }}{{end}}:\n\t\t\treturn true\n\t\t}\n\t\treturn false\n\t}\n\n\tfunc (e {{.GoType}}) String() string {\n\t\treturn string(e)\n\t}\n\n\tfunc (e *{{.GoType}}) UnmarshalGQL(v interface{}) error {\n\t\tstr, ok := v.(string)\n\t\tif !ok {\n\t\t\treturn fmt.Errorf(\"enums must be strings\")\n\t\t}\n\n\t\t*e = {{.GoType}}(str)\n\t\tif !e.IsValid() {\n\t\t\treturn fmt.Errorf(\"%s is not a valid {{.GQLType}}\", str)\n\t\t}\n\t\treturn nil\n\t}\n\n\tfunc (e {{.GoType}}) MarshalGQL(w io.Writer) {\n\t\tfmt.Fprint(w, strconv.Quote(e.String()))\n\t}\n\n{{- end }}\n", diff --git a/codegen/templates/generated.gotpl b/codegen/templates/generated.gotpl index c6b093db6e..7f258aa314 100644 --- a/codegen/templates/generated.gotpl +++ b/codegen/templates/generated.gotpl @@ -13,12 +13,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema { return &executableSchema{ resolvers: cfg.Resolvers, directives: cfg.Directives, + complexity: cfg.Complexity, } } type Config struct { Resolvers ResolverRoot Directives DirectiveRoot + Complexity ComplexityRoot } type ResolverRoot interface { @@ -35,7 +37,21 @@ type DirectiveRoot struct { {{ end }} } -{{- range $object := .Objects -}} +type ComplexityRoot struct { +{{ range $object := .Objects }} + {{ if $object.HasComplexity }} + {{ $object.GoType }} struct { + {{ range $field := $object.Fields }} + {{ if $field.CustomComplexity }} + {{ $field.GoFieldName }} {{ $field.ComplexitySignature }} + {{ end }} + {{ end }} + } + {{ end }} +{{ end }} +} + +{{ range $object := .Objects -}} {{ if $object.HasResolvers }} type {{$object.GQLType}}Resolver interface { {{ range $field := $object.Fields -}} @@ -48,12 +64,42 @@ type DirectiveRoot struct { type executableSchema struct { resolvers ResolverRoot directives DirectiveRoot + complexity ComplexityRoot } func (e *executableSchema) Schema() *ast.Schema { return parsedSchema } +func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) { + switch typeName + "." + field { + {{ range $object := .Objects }} + {{ range $field := $object.Fields }} + {{ if $field.CustomComplexity }} + case "{{$object.GQLType}}.{{$field.GQLName}}": + if e.complexity.{{$object.GoType}}.{{$field.GoFieldName}} == nil { + break + } + {{ if . }}args := map[string]interface{}{} {{end}} + {{ range $i, $arg := $field.Args }} + var arg{{$i}} {{$arg.Signature }} + if tmp, ok := rawArgs[{{$arg.GQLName|quote}}]; ok { + var err error + {{$arg.Unmarshal (print "arg" $i) "tmp" }} + if err != nil { + return 0, false + } + } + args[{{$arg.GQLName|quote}}] = arg{{$i}} + {{ end }} + return e.complexity.{{$object.GoType}}.{{$field.GoFieldName}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{end}}), true + {{ end }} + {{ end }} + {{ end }} + } + return 0, false +} + func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { {{- if .QueryRoot }} ec := executionContext{graphql.GetRequestContext(ctx), e} diff --git a/graphql/exec.go b/graphql/exec.go index 387c34cc09..9beb314907 100644 --- a/graphql/exec.go +++ b/graphql/exec.go @@ -10,6 +10,7 @@ import ( type ExecutableSchema interface { Schema() *ast.Schema + Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool) Query(ctx context.Context, op *ast.OperationDefinition) *Response Mutation(ctx context.Context, op *ast.OperationDefinition) *Response Subscription(ctx context.Context, op *ast.OperationDefinition) func() *Response diff --git a/handler/graphql.go b/handler/graphql.go index 0485af865b..50d7a09450 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -8,6 +8,7 @@ import ( "net/http" "strings" + "github.com/99designs/gqlgen/complexity" "github.com/99designs/gqlgen/graphql" "github.com/gorilla/websocket" "github.com/vektah/gqlparser" @@ -23,11 +24,12 @@ type params struct { } type Config struct { - upgrader websocket.Upgrader - recover graphql.RecoverFunc - errorPresenter graphql.ErrorPresenterFunc - resolverHook graphql.FieldMiddleware - requestHook graphql.RequestMiddleware + upgrader websocket.Upgrader + recover graphql.RecoverFunc + errorPresenter graphql.ErrorPresenterFunc + resolverHook graphql.FieldMiddleware + requestHook graphql.RequestMiddleware + complexityLimit int } func (c *Config) newRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *graphql.RequestContext { @@ -74,6 +76,14 @@ func ErrorPresenter(f graphql.ErrorPresenterFunc) Option { } } +// ComplexityLimit sets a maximum query complexity that is allowed to be executed. +// If a query is submitted that exceeds the limit, a 422 status code will be returned. +func ComplexityLimit(limit int) Option { + return func(cfg *Config) { + cfg.complexityLimit = limit + } +} + // ResolverMiddleware allows you to define a function that will be called around every resolver, // useful for tracing and logging. func ResolverMiddleware(middleware graphql.FieldMiddleware) Option { @@ -184,6 +194,14 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc } }() + if cfg.complexityLimit > 0 { + queryComplexity := complexity.Calculate(exec, op, vars) + if queryComplexity > cfg.complexityLimit { + sendErrorf(w, http.StatusUnprocessableEntity, "query has complexity %d, which exceeds the limit of %d", queryComplexity, cfg.complexityLimit) + return + } + } + switch op.Operation { case ast.Query: b, err := json.Marshal(exec.Query(ctx, op)) diff --git a/handler/stub.go b/handler/stub.go index a5b1a19b77..d237e18892 100644 --- a/handler/stub.go +++ b/handler/stub.go @@ -25,6 +25,10 @@ func (e *executableSchemaStub) Schema() *ast.Schema { `}) } +func (e *executableSchemaStub) Complexity(typeName, field string, childComplexity int, args map[string]interface{}) (int, bool) { + return 0, false +} + func (e *executableSchemaStub) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { return &graphql.Response{Data: []byte(`{"name":"test"}`)} }