Skip to content

Commit

Permalink
bind to types.Types in field / arg references too
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Jan 8, 2019
1 parent 38add2c commit 8298acb
Show file tree
Hide file tree
Showing 17 changed files with 146 additions and 179 deletions.
28 changes: 15 additions & 13 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,19 +296,21 @@ func (cfg *Config) normalize() error {
}

builtins := TypeMap{
"__Directive": {Model: "github.com/99designs/gqlgen/graphql/introspection.Directive"},
"__Type": {Model: "github.com/99designs/gqlgen/graphql/introspection.Type"},
"__Field": {Model: "github.com/99designs/gqlgen/graphql/introspection.Field"},
"__EnumValue": {Model: "github.com/99designs/gqlgen/graphql/introspection.EnumValue"},
"__InputValue": {Model: "github.com/99designs/gqlgen/graphql/introspection.InputValue"},
"__Schema": {Model: "github.com/99designs/gqlgen/graphql/introspection.Schema"},
"Int": {Model: "github.com/99designs/gqlgen/graphql.Int"},
"Float": {Model: "github.com/99designs/gqlgen/graphql.Float"},
"String": {Model: "github.com/99designs/gqlgen/graphql.String"},
"Boolean": {Model: "github.com/99designs/gqlgen/graphql.Boolean"},
"ID": {Model: "github.com/99designs/gqlgen/graphql.ID"},
"Time": {Model: "github.com/99designs/gqlgen/graphql.Time"},
"Map": {Model: "github.com/99designs/gqlgen/graphql.Map"},
"__Directive": {Model: "github.com/99designs/gqlgen/graphql/introspection.Directive"},
"__DirectiveLocation": {Model: "github.com/99designs/gqlgen/graphql.String"},
"__Type": {Model: "github.com/99designs/gqlgen/graphql/introspection.Type"},
"__TypeKind": {Model: "github.com/99designs/gqlgen/graphql.String"},
"__Field": {Model: "github.com/99designs/gqlgen/graphql/introspection.Field"},
"__EnumValue": {Model: "github.com/99designs/gqlgen/graphql/introspection.EnumValue"},
"__InputValue": {Model: "github.com/99designs/gqlgen/graphql/introspection.InputValue"},
"__Schema": {Model: "github.com/99designs/gqlgen/graphql/introspection.Schema"},
"Int": {Model: "github.com/99designs/gqlgen/graphql.Int"},
"Float": {Model: "github.com/99designs/gqlgen/graphql.Float"},
"String": {Model: "github.com/99designs/gqlgen/graphql.String"},
"Boolean": {Model: "github.com/99designs/gqlgen/graphql.Boolean"},
"ID": {Model: "github.com/99designs/gqlgen/graphql.ID"},
"Time": {Model: "github.com/99designs/gqlgen/graphql.Time"},
"Map": {Model: "github.com/99designs/gqlgen/graphql.Map"},
}

if cfg.Models == nil {
Expand Down
4 changes: 2 additions & 2 deletions codegen/directive.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (d *Directive) CallArgs() string {
args := []string{"ctx", "obj", "n"}

for _, arg := range d.Args {
args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+templates.CurrentImports.LookupType(arg.GoType)+")")
}

return strings.Join(args, ", ")
Expand Down Expand Up @@ -56,7 +56,7 @@ func (d *Directive) Declaration() string {
res := ucFirst(d.Name) + " func(ctx context.Context, obj interface{}, next graphql.Resolver"

for _, arg := range d.Args {
res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
res += fmt.Sprintf(", %s %s", arg.GoVarName, templates.CurrentImports.LookupType(arg.GoType))
}

res += ") (res interface{}, err error)"
Expand Down
2 changes: 1 addition & 1 deletion codegen/models_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func (g *Generator) buildModels(types NamedTypes, prog *loader.Program) ([]Model
}
switch typ.Kind {
case ast.Object:
obj, err := g.buildObject(types, typ)
obj, err := g.buildObject(prog, types, typ)
if err != nil {
return nil, err
}
Expand Down
47 changes: 26 additions & 21 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ func (f *Field) ShortResolverDeclaration() string {
res += fmt.Sprintf(", obj *%s", templates.CurrentImports.LookupType(f.Object.Definition.GoType))
}
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
res += fmt.Sprintf(", %s %s", arg.GoVarName, templates.CurrentImports.LookupType(arg.GoType))
}

result := f.Signature()
result := templates.CurrentImports.LookupType(f.GoType)
if f.Object.Stream {
result = "<-chan " + result
}
Expand All @@ -196,10 +196,10 @@ func (f *Field) ResolverDeclaration() string {
res += fmt.Sprintf(", obj *%s", templates.CurrentImports.LookupType(f.Object.Definition.GoType))
}
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
res += fmt.Sprintf(", %s %s", arg.GoVarName, templates.CurrentImports.LookupType(arg.GoType))
}

result := f.Signature()
result := templates.CurrentImports.LookupType(f.GoType)
if f.Object.Stream {
result = "<-chan " + result
}
Expand All @@ -211,7 +211,7 @@ func (f *Field) ResolverDeclaration() string {
func (f *Field) ComplexitySignature() string {
res := fmt.Sprintf("func(childComplexity int")
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
res += fmt.Sprintf(", %s %s", arg.GoVarName, templates.CurrentImports.LookupType(arg.GoType))
}
res += ") int"
return res
Expand All @@ -220,7 +220,7 @@ func (f *Field) ComplexitySignature() string {
func (f *Field) ComplexityArgs() string {
var args []string
for _, arg := range f.Args {
args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+templates.CurrentImports.LookupType(arg.GoType)+")")
}

return strings.Join(args, ", ")
Expand All @@ -242,20 +242,20 @@ func (f *Field) CallArgs() string {
}

for _, arg := range f.Args {
args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+templates.CurrentImports.LookupType(arg.GoType)+")")
}

return strings.Join(args, ", ")
}

// should be in the template, but its recursive and has a bunch of args
func (f *Field) WriteJson() string {
return f.doWriteJson("res", f.TypeReference.Modifiers, f.ASTType, false, 1)
return f.doWriteJson("res", f.GoType, f.ASTType, false, 1)
}

func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Type, isPtr bool, depth int) string {
switch {
case len(remainingMods) > 0 && remainingMods[0] == modPtr:
func (f *Field) doWriteJson(val string, destType types.Type, astType *ast.Type, isPtr bool, depth int) string {
switch destType := destType.(type) {
case *types.Pointer:
return tpl(`
if {{.val}} == nil {
{{- if .nonNull }}
Expand All @@ -268,18 +268,22 @@ func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Typ
{{.next }}`, map[string]interface{}{
"val": val,
"nonNull": astType.NonNull,
"next": f.doWriteJson(val, remainingMods[1:], astType, true, depth+1),
"next": f.doWriteJson(val, destType.Elem(), astType, true, depth+1),
})

case len(remainingMods) > 0 && remainingMods[0] == modList:
case *types.Slice:
if isPtr {
val = "*" + val
}
var arr = "arr" + strconv.Itoa(depth)
var index = "idx" + strconv.Itoa(depth)
var usePtr bool
if len(remainingMods) == 1 && !isPtr {
usePtr = true
if !isPtr {
switch destType.Elem().(type) {
case *types.Pointer, *types.Array:
default:
usePtr = true
}
}

return tpl(`
Expand Down Expand Up @@ -327,16 +331,17 @@ func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Typ
"arrayLen": len(val),
"isScalar": f.Definition.IsScalar,
"usePtr": usePtr,
"next": f.doWriteJson(val+"["+index+"]", remainingMods[1:], astType.Elem, false, depth+1),
"next": f.doWriteJson(val+"["+index+"]", destType.Elem(), astType.Elem, false, depth+1),
})

case f.Definition.IsScalar:
if isPtr {
val = "*" + val
default:
if f.Definition.IsScalar {
if isPtr {
val = "*" + val
}
return f.Marshal(val)
}
return f.Marshal(val)

default:
if !isPtr {
val = "&" + val
}
Expand Down
20 changes: 15 additions & 5 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (g *Generator) buildObjects(ts NamedTypes, prog *loader.Program) (Objects,
continue
}

obj, err := g.buildObject(ts, typ)
obj, err := g.buildObject(prog, ts, typ)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -80,7 +80,7 @@ func sanitizeArgName(name string) string {
return name
}

func (g *Generator) buildObject(ts NamedTypes, typ *ast.Definition) (*Object, error) {
func (g *Generator) buildObject(prog *loader.Program, ts NamedTypes, typ *ast.Definition) (*Object, error) {
obj := &Object{Definition: ts[typ.Name]}
typeEntry, entryExists := g.Models[typ.Name]

Expand Down Expand Up @@ -109,8 +109,13 @@ func (g *Generator) buildObject(ts NamedTypes, typ *ast.Definition) (*Object, er

for _, field := range typ.Fields {
if typ == g.schema.Query && field.Name == "__type" {
schemaType, err := findGoType(prog, "github.com/99designs/gqlgen/graphql/introspection", "Schema")
if err != nil {
return nil, errors.Wrap(err, "unable to find root schema introspection type")
}

obj.Fields = append(obj.Fields, Field{
TypeReference: &TypeReference{ts["__Schema"], []string{modPtr}, ast.NamedType("__Schema", nil)},
TypeReference: &TypeReference{ts["__Schema"], types.NewPointer(schemaType.Type()), ast.NamedType("__Schema", nil)},
GQLName: "__schema",
GoFieldType: GoFieldMethod,
GoReceiverName: "ec",
Expand All @@ -121,14 +126,19 @@ func (g *Generator) buildObject(ts NamedTypes, typ *ast.Definition) (*Object, er
continue
}
if typ == g.schema.Query && field.Name == "__schema" {
typeType, err := findGoType(prog, "github.com/99designs/gqlgen/graphql/introspection", "Type")
if err != nil {
return nil, errors.Wrap(err, "unable to find root schema introspection type")
}

obj.Fields = append(obj.Fields, Field{
TypeReference: &TypeReference{ts["__Type"], []string{modPtr}, ast.NamedType("__Schema", nil)},
TypeReference: &TypeReference{ts["__Type"], types.NewPointer(typeType.Type()), ast.NamedType("__Schema", nil)},
GQLName: "__type",
GoFieldType: GoFieldMethod,
GoReceiverName: "ec",
GoFieldName: "introspectType",
Args: []FieldArgument{
{GQLName: "name", TypeReference: &TypeReference{ts["String"], []string{}, ast.NamedType("String", nil)}, Object: &Object{}},
{GQLName: "name", TypeReference: &TypeReference{ts["String"], types.Typ[types.String], ast.NamedType("String", nil)}, Object: &Object{}},
},
Object: obj,
})
Expand Down
6 changes: 3 additions & 3 deletions codegen/templates/args.gotpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
args := map[string]interface{}{}
{{- range $i, $arg := . }}
var arg{{$i}} {{$arg.Signature }}
var arg{{$i}} {{$arg.GoType | ref }}
if tmp, ok := rawArgs[{{$arg.GQLName|quote}}]; ok {
{{- if $arg.Directives }}
argm{{$i}}, err := chainFieldMiddleware([]graphql.FieldMiddleware{
Expand All @@ -25,10 +25,10 @@
if err != nil {
return nil, err
}
if data, ok := argm{{$i}}.({{$arg.Signature }}); ok {
if data, ok := argm{{$i}}.({{$arg.GoType | ref }}); ok {
arg{{$i}} = data
} else {
return nil, errors.New("expect {{$arg.Signature }}")
return nil, errors.New("expect {{$arg.GoType | ref }}")
}
{{- else }}
var err error
Expand Down

0 comments on commit 8298acb

Please sign in to comment.