diff --git a/codegen/object.go b/codegen/object.go index 9bad02a0ca..3c5c14f4d9 100644 --- a/codegen/object.go +++ b/codegen/object.go @@ -35,8 +35,9 @@ type Field struct { type FieldArgument struct { *Type - GQLName string // The name of the argument in graphql - Object *Object // A link back to the parent object + GQLName string // The name of the argument in graphql + Object *Object // A link back to the parent object + Default interface{} // The default value } type Objects []*Object diff --git a/codegen/object_build.go b/codegen/object_build.go index 0aa170fee5..08c82ea4a0 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -98,11 +98,17 @@ func buildObject(types NamedTypes, typ *schema.Object) *Object { for _, field := range typ.Fields { var args []FieldArgument for _, arg := range field.Args { - args = append(args, FieldArgument{ + newArg := FieldArgument{ GQLName: arg.Name.Name, Type: types.getType(arg.Type), Object: obj, - }) + } + + if arg.Default != nil { + newArg.Default = arg.Default.Value(nil) + newArg.StripPtr() + } + args = append(args, newArg) } obj.Fields = append(obj.Fields, Field{ diff --git a/codegen/templates/args.gotpl b/codegen/templates/args.gotpl index e2d253d0ba..79f3971cbc 100644 --- a/codegen/templates/args.gotpl +++ b/codegen/templates/args.gotpl @@ -11,5 +11,18 @@ return graphql.Null {{- end }} } + } {{ if $arg.Default }} else { + tmp := {{ $arg.Default | dump }} + var err error + {{$arg.Unmarshal (print "arg" $i) "tmp" }} + if err != nil { + ec.Error(err) + {{- if $arg.Object.Stream }} + return nil + {{- else }} + return graphql.Null + {{- end }} + } } + {{end }} {{- end -}} diff --git a/codegen/templates/data.go b/codegen/templates/data.go index 9c86ba3253..e6f8b10776 100644 --- a/codegen/templates/data.go +++ b/codegen/templates/data.go @@ -1,7 +1,7 @@ package templates var data = map[string]string{ - "args.gotpl": "\t{{- range $i, $arg := . }}\n\t\tvar arg{{$i}} {{$arg.Signature }}\n\t\tif tmp, ok := field.Args[{{$arg.GQLName|quote}}]; ok {\n\t\t\tvar err error\n\t\t\t{{$arg.Unmarshal (print \"arg\" $i) \"tmp\" }}\n\t\t\tif err != nil {\n\t\t\t\tec.Error(err)\n\t\t\t\t{{- if $arg.Object.Stream }}\n\t\t\t\t\treturn nil\n\t\t\t\t{{- else }}\n\t\t\t\t\treturn graphql.Null\n\t\t\t\t{{- end }}\n\t\t\t}\n\t\t}\n\t{{- end -}}\n", + "args.gotpl": "\t{{- range $i, $arg := . }}\n\t\tvar arg{{$i}} {{$arg.Signature }}\n\t\tif tmp, ok := field.Args[{{$arg.GQLName|quote}}]; ok {\n\t\t\tvar err error\n\t\t\t{{$arg.Unmarshal (print \"arg\" $i) \"tmp\" }}\n\t\t\tif err != nil {\n\t\t\t\tec.Error(err)\n\t\t\t\t{{- if $arg.Object.Stream }}\n\t\t\t\t\treturn nil\n\t\t\t\t{{- else }}\n\t\t\t\t\treturn graphql.Null\n\t\t\t\t{{- end }}\n\t\t\t}\n\t\t} {{ if $arg.Default }} else {\n\t\t\ttmp := {{ $arg.Default | dump }}\n\t\t\tvar err error\n\t\t\t{{$arg.Unmarshal (print \"arg\" $i) \"tmp\" }}\n\t\t\tif err != nil {\n\t\t\t\tec.Error(err)\n\t\t\t\t{{- if $arg.Object.Stream }}\n\t\t\t\t\treturn nil\n\t\t\t\t{{- else }}\n\t\t\t\t\treturn graphql.Null\n\t\t\t\t{{- end }}\n\t\t\t}\n\t\t}\n\t\t{{end }}\n\t{{- end -}}\n", "field.gotpl": "{{ $field := . }}\n{{ $object := $field.Object }}\n\n{{- if $object.Stream }}\n\tfunc (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField) func() graphql.Marshaler {\n\t\t{{- template \"args.gotpl\" $field.Args }}\n\t\tresults, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})\n\t\tif err != nil {\n\t\t\tec.Error(err)\n\t\t\treturn nil\n\t\t}\n\t\treturn func() graphql.Marshaler {\n\t\t\tres, ok := <-results\n\t\t\tif !ok {\n\t\t\t\treturn nil\n\t\t\t}\n\t\t\tvar out graphql.OrderedMap\n\t\t\tout.Add(field.Alias, func() graphql.Marshaler { {{ $field.WriteJson }} }())\n\t\t\treturn &out\n\t\t}\n\t}\n{{ else }}\n\tfunc (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField, {{if not $object.Root}}obj *{{$object.FullName}}{{end}}) graphql.Marshaler {\n\t\t{{- template \"args.gotpl\" $field.Args }}\n\n\t\t{{- if $field.IsConcurrent }}\n\t\t\treturn graphql.Defer(func() graphql.Marshaler {\n\t\t{{- end }}\n\n\t\t\t{{- if $field.GoVarName }}\n\t\t\t\tres := obj.{{$field.GoVarName}}\n\t\t\t{{- else if $field.GoMethodName }}\n\t\t\t\t{{- if $field.NoErr }}\n\t\t\t\t\tres := {{$field.GoMethodName}}({{ $field.CallArgs }})\n\t\t\t\t{{- else }}\n\t\t\t\t\tres, err := {{$field.GoMethodName}}({{ $field.CallArgs }})\n\t\t\t\t\tif err != nil {\n\t\t\t\t\t\tec.Error(err)\n\t\t\t\t\t\treturn graphql.Null\n\t\t\t\t\t}\n\t\t\t\t{{- end }}\n\t\t\t{{- else }}\n\t\t\t\tres, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})\n\t\t\t\tif err != nil {\n\t\t\t\t\tec.Error(err)\n\t\t\t\t\treturn graphql.Null\n\t\t\t\t}\n\t\t\t{{- end }}\n\t\t\t{{ $field.WriteJson }}\n\t\t{{- if $field.IsConcurrent }}\n\t\t\t})\n\t\t{{- end }}\n\t}\n{{ end }}\n", "file.gotpl": "// This file was generated by github.com/vektah/gqlgen, DO NOT EDIT\n\npackage {{ .PackageName }}\n\nimport (\n{{- range $import := .Imports }}\n\t{{- $import.Write }}\n{{ end }}\n)\n\nfunc MakeExecutableSchema(resolvers Resolvers) graphql.ExecutableSchema {\n\treturn &executableSchema{resolvers}\n}\n\ntype Resolvers interface {\n{{- range $object := .Objects -}}\n\t{{ range $field := $object.Fields -}}\n\t\t{{ $field.ResolverDeclaration }}\n\t{{ end }}\n{{- end }}\n}\n\n{{ range $model := .Models }}\n\t{{ template \"model.gotpl\" $model }}\n{{- end}}\n\ntype executableSchema struct {\n\tresolvers Resolvers\n}\n\nfunc (e *executableSchema) Schema() *schema.Schema {\n\treturn parsedSchema\n}\n\nfunc (e *executableSchema) Query(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation) *graphql.Response {\n\t{{- if .QueryRoot }}\n\t\tec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx}\n\n\t\tdata := ec._{{.QueryRoot.GQLType}}(op.Selections)\n\t\tvar buf bytes.Buffer\n\t\tdata.MarshalGQL(&buf)\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf.Bytes(),\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn &graphql.Response{Errors: []*errors.QueryError{ {Message: \"queries are not supported\"} }}\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Mutation(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation) *graphql.Response {\n\t{{- if .MutationRoot }}\n\t\tec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx}\n\n\t\tdata := ec._{{.MutationRoot.GQLType}}(op.Selections)\n\t\tvar buf bytes.Buffer\n\t\tdata.MarshalGQL(&buf)\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf.Bytes(),\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn &graphql.Response{Errors: []*errors.QueryError{ {Message: \"mutations are not supported\"} }}\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Subscription(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation) func() *graphql.Response {\n\t{{- if .SubscriptionRoot }}\n\t\tec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx}\n\n\t\tnext := ec._{{.SubscriptionRoot.GQLType}}(op.Selections)\n\t\tif ec.Errors != nil {\n\t\t\treturn graphql.OneShot(&graphql.Response{Data: []byte(\"null\"), Errors: ec.Errors})\n\t\t}\n\n\t\tvar buf bytes.Buffer\n\t\treturn func() *graphql.Response {\n\t\t\tbuf.Reset()\n\t\t\tdata := next()\n\t\t\tif data == nil {\n\t\t\t\treturn nil\n\t\t\t}\n\t\t\tdata.MarshalGQL(&buf)\n\n\t\t\terrs := ec.Errors\n\t\t\tec.Errors = nil\n\t\t\treturn &graphql.Response{\n\t\t\t\tData: buf.Bytes(),\n\t\t\t\tErrors: errs,\n\t\t\t}\n\t\t}\n\t{{- else }}\n\t\treturn graphql.OneShot(&graphql.Response{Errors: []*errors.QueryError{ {Message: \"subscriptions are not supported\"} }})\n\t{{- end }}\n}\n\ntype executionContext struct {\n\terrors.Builder\n\tresolvers Resolvers\n\tvariables map[string]interface{}\n\tdoc *query.Document\n\tctx context.Context\n}\n\n{{- range $object := .Objects }}\n\t{{ template \"object.gotpl\" $object }}\n\n\t{{- range $field := $object.Fields }}\n\t\t{{ template \"field.gotpl\" $field }}\n\t{{ end }}\n{{- end}}\n\n{{- range $interface := .Interfaces }}\n\t{{ template \"interface.gotpl\" $interface }}\n{{- end }}\n\n{{- range $input := .Inputs }}\n\t{{ template \"input.gotpl\" $input }}\n{{- end }}\n\nvar parsedSchema = schema.MustParse({{.SchemaRaw|quote}})\n\nfunc (ec *executionContext) introspectSchema() *introspection.Schema {\n\treturn introspection.WrapSchema(parsedSchema)\n}\n\nfunc (ec *executionContext) introspectType(name string) *introspection.Type {\n\tt := parsedSchema.Resolve(name)\n\tif t == nil {\n\t\treturn nil\n\t}\n\treturn introspection.WrapType(t)\n}\n", "input.gotpl": "\t{{- if .IsMarshaled }}\n\tfunc Unmarshal{{ .GQLType }}(v interface{}) ({{.FullName}}, error) {\n\t\tvar it {{.FullName}}\n\n\t\tfor k, v := range v.(map[string]interface{}) {\n\t\t\tswitch k {\n\t\t\t{{- range $field := .Fields }}\n\t\t\tcase {{$field.GQLName|quote}}:\n\t\t\t\tvar err error\n\t\t\t\t{{ $field.Unmarshal (print \"it.\" $field.GoVarName) \"v\" }}\n\t\t\t\tif err != nil {\n\t\t\t\t\treturn it, err\n\t\t\t\t}\n\t\t\t{{- end }}\n\t\t\t}\n\t\t}\n\n\t\treturn it, nil\n\t}\n\t{{- end }}\n", diff --git a/codegen/templates/templates.go b/codegen/templates/templates.go index 49eec9efcf..54a9b1b2c8 100644 --- a/codegen/templates/templates.go +++ b/codegen/templates/templates.go @@ -4,7 +4,9 @@ package templates import ( "bytes" + "fmt" "strconv" + "strings" "text/template" "unicode" @@ -16,6 +18,7 @@ func Run(e *codegen.Build) (*bytes.Buffer, error) { "ucFirst": ucFirst, "lcFirst": lcFirst, "quote": strconv.Quote, + "dump": dump, }) for filename, data := range data { @@ -52,3 +55,36 @@ func lcFirst(s string) string { r[0] = unicode.ToLower(r[0]) return string(r) } + +func dump(val interface{}) string { + switch val := val.(type) { + case int: + return strconv.Itoa(val) + case float64: + return fmt.Sprintf("%f", val) + case string: + return strconv.Quote(val) + case bool: + return strconv.FormatBool(val) + case nil: + return "nil" + case []interface{}: + var parts []string + for _, part := range val { + parts = append(parts, dump(part)) + } + return "[]interface{}{" + strings.Join(parts, ",") + "}" + case map[string]interface{}: + buf := bytes.Buffer{} + buf.WriteString("map[string]interface{}{") + for key, data := range val { + buf.WriteString(strconv.Quote(key)) + buf.WriteString(":") + buf.WriteString(dump(data)) + } + buf.WriteString("}") + return buf.String() + default: + panic(fmt.Errorf("unsupported type %T", val)) + } +} diff --git a/codegen/type.go b/codegen/type.go index 6eb40f9e82..1e6e76d337 100644 --- a/codegen/type.go +++ b/codegen/type.go @@ -62,6 +62,13 @@ func (t Type) IsPtr() bool { return len(t.Modifiers) > 0 && t.Modifiers[0] == modPtr } +func (t *Type) StripPtr() { + if !t.IsPtr() { + return + } + t.Modifiers = t.Modifiers[0 : len(t.Modifiers)-1] +} + func (t Type) IsSlice() bool { return len(t.Modifiers) > 0 && t.Modifiers[0] == modList } diff --git a/codegen/type_build.go b/codegen/type_build.go index 86be1265a5..50fa32dc83 100644 --- a/codegen/type_build.go +++ b/codegen/type_build.go @@ -101,8 +101,8 @@ func (n NamedTypes) getType(t common.Type) *Type { Modifiers: modifiers, } - if t.IsInterface && t.Modifiers[len(t.Modifiers)-1] == modPtr { - t.Modifiers = t.Modifiers[0 : len(t.Modifiers)-1] + if t.IsInterface { + t.StripPtr() } return t diff --git a/example/scalars/generated.go b/example/scalars/generated.go index 96d44cb5ed..2d37a48824 100644 --- a/example/scalars/generated.go +++ b/example/scalars/generated.go @@ -118,12 +118,22 @@ func (ec *executionContext) _Query_search(field graphql.CollectedField) graphql. if tmp, ok := field.Args["input"]; ok { var err error + arg0, err = UnmarshalSearchArgs(tmp) + if err != nil { + ec.Error(err) + return graphql.Null + } + } else { + tmp := map[string]interface{}{"location": "37,144"} + var err error + arg0, err = UnmarshalSearchArgs(tmp) if err != nil { ec.Error(err) return graphql.Null } } + return graphql.Defer(func() graphql.Marshaler { res, err := ec.resolvers.Query_search(ec.ctx, arg0) if err != nil { @@ -751,7 +761,7 @@ func UnmarshalSearchArgs(v interface{}) (SearchArgs, error) { return it, nil } -var parsedSchema = schema.MustParse("schema {\n query: Query\n}\n\ntype Query {\n user(id: ID!): User\n search(input: SearchArgs!): [User!]!\n}\n\ntype User {\n id: ID!\n name: String!\n created: Timestamp\n location: Point\n isBanned: Boolean!\n}\n\ninput SearchArgs {\n location: Point\n createdAfter: Timestamp\n isBanned: Boolean\n}\n\nscalar Timestamp\nscalar Point\n") +var parsedSchema = schema.MustParse("schema {\n query: Query\n}\n\ntype Query {\n user(id: ID!): User\n search(input: SearchArgs = {location: \"37,144\"}): [User!]!\n}\n\ntype User {\n id: ID!\n name: String!\n created: Timestamp\n location: Point\n isBanned: Boolean!\n}\n\ninput SearchArgs {\n location: Point\n createdAfter: Timestamp\n isBanned: Boolean\n}\n\nscalar Timestamp\nscalar Point\n") func (ec *executionContext) introspectSchema() *introspection.Schema { return introspection.WrapSchema(parsedSchema) diff --git a/example/scalars/scalar_test.go b/example/scalars/scalar_test.go index 137e01a519..835cb434d2 100644 --- a/example/scalars/scalar_test.go +++ b/example/scalars/scalar_test.go @@ -43,6 +43,14 @@ func TestScalars(t *testing.T) { require.Equal(t, int64(666), resp.Search[0].Created) }) + t.Run("default search location", func(t *testing.T) { + var resp struct{ Search []RawUser } + + err := c.Post(`{ search { location } }`, &resp) + require.NoError(t, err) + require.Equal(t, "37,144", resp.Search[0].Location) + }) + t.Run("test custom error messages", func(t *testing.T) { var resp struct{ Search []RawUser } diff --git a/example/scalars/schema.graphql b/example/scalars/schema.graphql index 18ac9a60b0..f841391ac4 100644 --- a/example/scalars/schema.graphql +++ b/example/scalars/schema.graphql @@ -4,7 +4,7 @@ schema { type Query { user(id: ID!): User - search(input: SearchArgs!): [User!]! + search(input: SearchArgs = {location: "37,144"}): [User!]! } type User { diff --git a/example/starwars/generated.go b/example/starwars/generated.go index 586d8e0a35..1df185613f 100644 --- a/example/starwars/generated.go +++ b/example/starwars/generated.go @@ -33,7 +33,7 @@ type Resolvers interface { Human_starships(ctx context.Context, obj *Human) ([]Starship, error) Mutation_createReview(ctx context.Context, episode string, review Review) (*Review, error) - Query_hero(ctx context.Context, episode *string) (Character, error) + Query_hero(ctx context.Context, episode string) (Character, error) Query_reviews(ctx context.Context, episode string, since *time.Time) ([]Review, error) Query_search(ctx context.Context, text string) ([]SearchResult, error) Query_character(ctx context.Context, id string) (Character, error) @@ -347,12 +347,22 @@ func (ec *executionContext) _Human_height(field graphql.CollectedField, obj *Hum if tmp, ok := field.Args["unit"]; ok { var err error + arg0, err = graphql.UnmarshalString(tmp) + if err != nil { + ec.Error(err) + return graphql.Null + } + } else { + tmp := "METER" + var err error + arg0, err = graphql.UnmarshalString(tmp) if err != nil { ec.Error(err) return graphql.Null } } + res := obj.Height(arg0) return graphql.MarshalFloat(res) } @@ -570,18 +580,26 @@ func (ec *executionContext) _Query(sel []query.Selection) graphql.Marshaler { } func (ec *executionContext) _Query_hero(field graphql.CollectedField) graphql.Marshaler { - var arg0 *string + var arg0 string if tmp, ok := field.Args["episode"]; ok { var err error - var ptr1 string - ptr1, err = graphql.UnmarshalString(tmp) - arg0 = &ptr1 + arg0, err = graphql.UnmarshalString(tmp) + if err != nil { + ec.Error(err) + return graphql.Null + } + } else { + tmp := "NEWHOPE" + var err error + + arg0, err = graphql.UnmarshalString(tmp) if err != nil { ec.Error(err) return graphql.Null } } + return graphql.Defer(func() graphql.Marshaler { res, err := ec.resolvers.Query_hero(ec.ctx, arg0) if err != nil { @@ -860,12 +878,22 @@ func (ec *executionContext) _Starship_length(field graphql.CollectedField, obj * if tmp, ok := field.Args["unit"]; ok { var err error + arg0, err = graphql.UnmarshalString(tmp) + if err != nil { + ec.Error(err) + return graphql.Null + } + } else { + tmp := "METER" + var err error + arg0, err = graphql.UnmarshalString(tmp) if err != nil { ec.Error(err) return graphql.Null } } + res := obj.Length(arg0) return graphql.MarshalFloat(res) } diff --git a/example/starwars/resolvers.go b/example/starwars/resolvers.go index 04c97068e6..08475d5365 100644 --- a/example/starwars/resolvers.go +++ b/example/starwars/resolvers.go @@ -85,8 +85,8 @@ func (r *Resolver) Mutation_createReview(ctx context.Context, episode string, re return &review, nil } -func (r *Resolver) Query_hero(ctx context.Context, episode *string) (Character, error) { - if episode != nil && *episode == "EMPIRE" { +func (r *Resolver) Query_hero(ctx context.Context, episode string) (Character, error) { + if episode == "EMPIRE" { return r.humans["1000"], nil } return r.droid["2001"], nil diff --git a/example/starwars/starwars_test.go b/example/starwars/starwars_test.go index 04e9646265..8d7a8d2ad9 100644 --- a/example/starwars/starwars_test.go +++ b/example/starwars/starwars_test.go @@ -84,6 +84,17 @@ func TestStarwars(t *testing.T) { require.Equal(t, 1.72, resp.Hero.Height) }) + t.Run("default hero episode", func(t *testing.T) { + var resp struct { + Hero struct { + Name string + } + } + c.MustPost(`{ hero { ... on Droid { name } } }`, &resp) + + require.Equal(t, "R2-D2", resp.Hero.Name) + }) + t.Run("friends", func(t *testing.T) { var resp struct { Human struct {