Skip to content

Commit

Permalink
Add complexity support to codegen, handler
Browse files Browse the repository at this point in the history
  • Loading branch information
edsrzf committed Aug 24, 2018
1 parent 95ed529 commit 238a7e2
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 25 deletions.
5 changes: 3 additions & 2 deletions codegen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ type TypeMapEntry struct {
}

type TypeMapField struct {
Resolver bool `yaml:"resolver"`
FieldName string `yaml:"fieldName"`
Resolver bool `yaml:"resolver"`
Complexity bool `yaml:"complexity"`
FieldName string `yaml:"fieldName"`
}

func (c *PackageConfig) normalize() error {
Expand Down
48 changes: 38 additions & 10 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,16 @@ type Object struct {
type Field struct {
*Type

GQLName string // The name of the field in graphql
GoFieldType GoFieldType // The field type in go, if any
GoReceiverName string // The name of method & var receiver in go, if any
GoFieldName string // The name of the method or var in go, if any
Args []FieldArgument // A list of arguments to be passed to this field
ForceResolver bool // Should be emit Resolver method
NoErr bool // If this is bound to a go method, does that method have an error as the second argument
Object *Object // A link back to the parent object
Default interface{} // The default value
GQLName string // The name of the field in graphql
GoFieldType GoFieldType // The field type in go, if any
GoReceiverName string // The name of method & var receiver in go, if any
GoFieldName string // The name of the method or var in go, if any
Args []FieldArgument // A list of arguments to be passed to this field
ForceResolver bool // Should be emit Resolver method
CustomComplexity bool // Uses a custom complexity calculation
NoErr bool // If this is bound to a go method, does that method have an error as the second argument
Object *Object // A link back to the parent object
Default interface{} // The default value
}

type FieldArgument struct {
Expand Down Expand Up @@ -81,6 +82,15 @@ func (o *Object) IsConcurrent() bool {
return false
}

func (o *Object) HasComplexity() bool {
for _, f := range o.Fields {
if f.CustomComplexity {
return true
}
}
return false
}

func (f *Field) IsResolver() bool {
return f.GoFieldName == ""
}
Expand Down Expand Up @@ -165,6 +175,24 @@ func (f *Field) ResolverDeclaration() string {
return res
}

func (f *Field) ComplexitySignature() string {
res := fmt.Sprintf("func(childComplexity int")
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
}
res += ") int"
return res
}

func (f *Field) ComplexityArgs() string {
var args []string
for _, arg := range f.Args {
args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
}

return strings.Join(args, ", ")
}

func (f *Field) CallArgs() string {
var args []string

Expand Down Expand Up @@ -227,7 +255,7 @@ func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Typ
ctx := graphql.WithResolverContext(ctx, rctx)
{{- end}}
{{.arr}} = append({{.arr}}, func() graphql.Marshaler {
{{ .next }}
{{ .next }}
}())
}
return {{.arr}}`, map[string]interface{}{
Expand Down
15 changes: 9 additions & 6 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,13 @@ func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition, imports *I
}

var forceResolver bool
var customComplexity bool
var goName string
if entryExists {
if typeField, ok := typeEntry.Fields[field.Name]; ok {
goName = typeField.FieldName
forceResolver = typeField.Resolver
customComplexity = typeField.Complexity
}
}

Expand Down Expand Up @@ -168,12 +170,13 @@ func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition, imports *I
}

obj.Fields = append(obj.Fields, Field{
GQLName: field.Name,
Type: types.getType(field.Type),
Args: args,
Object: obj,
GoFieldName: goName,
ForceResolver: forceResolver,
GQLName: field.Name,
Type: types.getType(field.Type),
Args: args,
Object: obj,
GoFieldName: goName,
ForceResolver: forceResolver,
CustomComplexity: customComplexity,
})
}

Expand Down
2 changes: 1 addition & 1 deletion codegen/templates/data.go

Large diffs are not rendered by default.

48 changes: 47 additions & 1 deletion codegen/templates/generated.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema {
return &executableSchema{
resolvers: cfg.Resolvers,
directives: cfg.Directives,
complexity: cfg.Complexity,
}
}

type Config struct {
Resolvers ResolverRoot
Directives DirectiveRoot
Complexity ComplexityRoot
}

type ResolverRoot interface {
Expand All @@ -35,7 +37,21 @@ type DirectiveRoot struct {
{{ end }}
}

{{- range $object := .Objects -}}
type ComplexityRoot struct {
{{ range $object := .Objects }}
{{ if $object.HasComplexity }}
{{ $object.GoType }} struct {
{{ range $field := $object.Fields }}
{{ if $field.CustomComplexity }}
{{ $field.GoFieldName }} {{ $field.ComplexitySignature }}
{{ end }}
{{ end }}
}
{{ end }}
{{ end }}
}

{{ range $object := .Objects -}}
{{ if $object.HasResolvers }}
type {{$object.GQLType}}Resolver interface {
{{ range $field := $object.Fields -}}
Expand All @@ -48,12 +64,42 @@ type DirectiveRoot struct {
type executableSchema struct {
resolvers ResolverRoot
directives DirectiveRoot
complexity ComplexityRoot
}

func (e *executableSchema) Schema() *ast.Schema {
return parsedSchema
}

func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {
switch typeName + "." + field {
{{ range $object := .Objects }}
{{ range $field := $object.Fields }}
{{ if $field.CustomComplexity }}
case "{{$object.GQLType}}.{{$field.GQLName}}":
if e.complexity.{{$object.GoType}}.{{$field.GoFieldName}} == nil {
break
}
{{ if . }}args := map[string]interface{}{} {{end}}
{{ range $i, $arg := $field.Args }}
var arg{{$i}} {{$arg.Signature }}
if tmp, ok := rawArgs[{{$arg.GQLName|quote}}]; ok {
var err error
{{$arg.Unmarshal (print "arg" $i) "tmp" }}
if err != nil {
return 0, false
}
}
args[{{$arg.GQLName|quote}}] = arg{{$i}}
{{ end }}
return e.complexity.{{$object.GoType}}.{{$field.GoFieldName}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{end}}), true
{{ end }}
{{ end }}
{{ end }}
}
return 0, false
}

func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {
{{- if .QueryRoot }}
ec := executionContext{graphql.GetRequestContext(ctx), e}
Expand Down
1 change: 1 addition & 0 deletions graphql/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
type ExecutableSchema interface {
Schema() *ast.Schema

Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
Query(ctx context.Context, op *ast.OperationDefinition) *Response
Mutation(ctx context.Context, op *ast.OperationDefinition) *Response
Subscription(ctx context.Context, op *ast.OperationDefinition) func() *Response
Expand Down
28 changes: 23 additions & 5 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"strings"

"github.com/99designs/gqlgen/complexity"
"github.com/99designs/gqlgen/graphql"
"github.com/gorilla/websocket"
"github.com/vektah/gqlparser"
Expand All @@ -23,11 +24,12 @@ type params struct {
}

type Config struct {
upgrader websocket.Upgrader
recover graphql.RecoverFunc
errorPresenter graphql.ErrorPresenterFunc
resolverHook graphql.FieldMiddleware
requestHook graphql.RequestMiddleware
upgrader websocket.Upgrader
recover graphql.RecoverFunc
errorPresenter graphql.ErrorPresenterFunc
resolverHook graphql.FieldMiddleware
requestHook graphql.RequestMiddleware
complexityLimit int
}

func (c *Config) newRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *graphql.RequestContext {
Expand Down Expand Up @@ -74,6 +76,14 @@ func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
}
}

// ComplexityLimit sets a maximum query complexity that is allowed to be executed.
// If a query is submitted that exceeds the limit, a 422 status code will be returned.
func ComplexityLimit(limit int) Option {
return func(cfg *Config) {
cfg.complexityLimit = limit
}
}

// ResolverMiddleware allows you to define a function that will be called around every resolver,
// useful for tracing and logging.
func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
Expand Down Expand Up @@ -184,6 +194,14 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc
}
}()

if cfg.complexityLimit > 0 {
queryComplexity := complexity.Calculate(exec, op, vars)
if queryComplexity > cfg.complexityLimit {
sendErrorf(w, http.StatusUnprocessableEntity, "query has complexity %d, which exceeds the limit of %d", queryComplexity, cfg.complexityLimit)
return
}
}

switch op.Operation {
case ast.Query:
b, err := json.Marshal(exec.Query(ctx, op))
Expand Down
4 changes: 4 additions & 0 deletions handler/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ func (e *executableSchemaStub) Schema() *ast.Schema {
`})
}

func (e *executableSchemaStub) Complexity(typeName, field string, childComplexity int, args map[string]interface{}) (int, bool) {
return 0, false
}

func (e *executableSchemaStub) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {
return &graphql.Response{Data: []byte(`{"name":"test"}`)}
}
Expand Down

0 comments on commit 238a7e2

Please sign in to comment.