Skip to content

Commit

Permalink
Bind to types.Type directly to remove TypeImplementation
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Jan 8, 2019
1 parent 70c852e commit 950ff42
Show file tree
Hide file tree
Showing 30 changed files with 276 additions and 273 deletions.
41 changes: 28 additions & 13 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,23 @@ type ServerBuild struct {

// Create a list of models that need to be generated
func (g *Generator) models() (*ModelBuild, error) {
namedTypes := g.buildNamedTypes()

progLoader := g.newLoaderWithoutErrors()

prog, err := progLoader.Load()
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}

g.bindTypes(namedTypes, g.Model.Dir(), prog)
namedTypes, err := g.buildNamedTypes(prog)
if err != nil {
return nil, errors.Wrap(err, "binding types failed")
}

directives, err := g.buildDirectives(namedTypes)
if err != nil {
return nil, err
}
g.Directives = directives

models, err := g.buildModels(namedTypes, prog)
if err != nil {
Expand All @@ -77,11 +84,16 @@ func (g *Generator) resolver() (*ResolverBuild, error) {
return nil, err
}

destDir := g.Resolver.Dir()

namedTypes := g.buildNamedTypes()
namedTypes, err := g.buildNamedTypes(prog)
if err != nil {
return nil, errors.Wrap(err, "binding types failed")
}

g.bindTypes(namedTypes, destDir, prog)
directives, err := g.buildDirectives(namedTypes)
if err != nil {
return nil, err
}
g.Directives = directives

objects, err := g.buildObjects(namedTypes, prog)
if err != nil {
Expand Down Expand Up @@ -109,26 +121,29 @@ func (g *Generator) server(destDir string) *ServerBuild {

// bind a schema together with some code to generate a Build
func (g *Generator) bind() (*Build, error) {
namedTypes := g.buildNamedTypes()

progLoader := g.newLoaderWithoutErrors()
prog, err := progLoader.Load()
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}

g.bindTypes(namedTypes, g.Exec.Dir(), prog)
namedTypes, err := g.buildNamedTypes(prog)
if err != nil {
return nil, errors.Wrap(err, "binding types failed")
}

objects, err := g.buildObjects(namedTypes, prog)
directives, err := g.buildDirectives(namedTypes)
if err != nil {
return nil, err
}
g.Directives = directives

inputs, err := g.buildInputs(namedTypes, prog)
objects, err := g.buildObjects(namedTypes, prog)
if err != nil {
return nil, err
}
directives, err := g.buildDirectives(namedTypes)

inputs, err := g.buildInputs(namedTypes, prog)
if err != nil {
return nil, err
}
Expand Down
11 changes: 11 additions & 0 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"sort"
"strings"

"go/types"

"github.com/99designs/gqlgen/internal/gopath"
"github.com/pkg/errors"
"github.com/vektah/gqlparser"
Expand Down Expand Up @@ -168,6 +170,10 @@ func (c *PackageConfig) Check() error {
return c.normalize()
}

func (c *PackageConfig) Pkg() *types.Package {
return types.NewPackage(c.ImportPath(), c.Dir())
}

func (c *PackageConfig) IsDefined() bool {
return c.Filename != ""
}
Expand Down Expand Up @@ -198,6 +204,11 @@ func (tm TypeMap) Exists(typeName string) bool {
return ok
}

func (tm TypeMap) UserDefined(typeName string) bool {
m, ok := tm[typeName]
return ok && m.Model != ""
}

func (tm TypeMap) Check() error {
for typeName, entry := range tm {
if strings.LastIndex(entry.Model, ".") < strings.LastIndex(entry.Model, "/") {
Expand Down
11 changes: 7 additions & 4 deletions codegen/enum_build.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
package codegen

import (
"go/types"
"sort"
"strings"

"github.com/99designs/gqlgen/codegen/templates"
"github.com/vektah/gqlparser/ast"
)

func (g *Generator) buildEnums(types NamedTypes) []Enum {
func (g *Generator) buildEnums(ts NamedTypes) []Enum {
var enums []Enum

for _, typ := range g.schema.Types {
namedType := types[typ.Name]
if typ.Kind != ast.Enum || strings.HasPrefix(typ.Name, "__") || namedType.IsUserDefined {
namedType := ts[typ.Name]
if typ.Kind != ast.Enum || strings.HasPrefix(typ.Name, "__") || g.Models.UserDefined(typ.Name) {
continue
}

Expand All @@ -27,7 +28,9 @@ func (g *Generator) buildEnums(types NamedTypes) []Enum {
Values: values,
Description: typ.Description,
}
enum.GoType = templates.ToCamel(enum.GQLType)

enum.GoType = types.NewNamed(types.NewTypeName(0, g.Config.Model.Pkg(), templates.ToCamel(enum.GQLType), nil), nil, nil)

enums = append(enums, enum)
}

Expand Down
13 changes: 3 additions & 10 deletions codegen/generator.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package codegen

import (
"go/types"
"log"
"os"
"path/filepath"
Expand Down Expand Up @@ -40,14 +41,6 @@ func (g *Generator) Generate() error {
_ = syscall.Unlink(g.Exec.Filename)
_ = syscall.Unlink(g.Model.Filename)

namedTypes := g.buildNamedTypes()

directives, err := g.buildDirectives(namedTypes)
if err != nil {
return err
}
g.Directives = directives

modelsBuild, err := g.models()
if err != nil {
return errors.Wrap(err, "model plan failed")
Expand All @@ -59,13 +52,13 @@ func (g *Generator) Generate() error {

for _, model := range modelsBuild.Models {
modelCfg := g.Models[model.GQLType]
modelCfg.Model = g.Model.ImportPath() + "." + model.GoType
modelCfg.Model = types.TypeString(model.GoType, nil)
g.Models[model.GQLType] = modelCfg
}

for _, enum := range modelsBuild.Enums {
modelCfg := g.Models[enum.GQLType]
modelCfg.Model = g.Model.ImportPath() + "." + enum.GoType
modelCfg.Model = types.TypeString(enum.GoType, nil)
g.Models[enum.GQLType] = modelCfg
}
}
Expand Down
29 changes: 4 additions & 25 deletions codegen/input_build.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package codegen

import (
"go/types"
"sort"

"go/types"

"github.com/pkg/errors"
"github.com/vektah/gqlparser/ast"
"golang.org/x/tools/go/loader"
Expand All @@ -20,13 +21,8 @@ func (g *Generator) buildInputs(namedTypes NamedTypes, prog *loader.Program) (Ob
return nil, err
}

def, err := findGoType(prog, input.Package, input.GoType)
if err != nil {
return nil, errors.Wrap(err, "cannot find type")
}
if def != nil {
input.Marshaler = buildInputMarshaler(typ, def)
bindErrs := bindObject(def.Type(), input, g.StructTag)
if _, isMap := input.GoType.(*types.Map); !isMap {
bindErrs := bindObject(input, g.StructTag)
if len(bindErrs) > 0 {
return nil, bindErrs
}
Expand Down Expand Up @@ -88,20 +84,3 @@ func (g *Generator) buildInput(types NamedTypes, typ *ast.Definition) (*Object,

return obj, nil
}

// if user has implemented an UnmarshalGQL method on the input type manually, use it
// otherwise we will generate one.
func buildInputMarshaler(typ *ast.Definition, def types.Object) *TypeImplementation {
switch def := def.(type) {
case *types.TypeName:
namedType := def.Type().(*types.Named)
for i := 0; i < namedType.NumMethods(); i++ {
method := namedType.Method(i)
if method.Name() == "UnmarshalGQL" {
return nil
}
}
}

return &TypeImplementation{GoType: typ.Name}
}
4 changes: 2 additions & 2 deletions codegen/interface_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ func (g *Generator) buildInterface(types NamedTypes, typ *ast.Definition, prog *
}

func (g *Generator) isValueReceiver(intf *TypeDefinition, implementor *TypeDefinition, prog *loader.Program) bool {
interfaceType, err := findGoInterface(prog, intf.Package, intf.GoType)
interfaceType, err := findGoInterface(intf.GoType)
if interfaceType == nil || err != nil {
return true
}

implementorType, err := findGoNamedType(prog, implementor.Package, implementor.GoType)
implementorType, err := findGoNamedType(implementor.GoType)
if implementorType == nil || err != nil {
return true
}
Expand Down
17 changes: 4 additions & 13 deletions codegen/models_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@ func (g *Generator) buildModels(types NamedTypes, prog *loader.Program) ([]Model

for _, typ := range g.schema.Types {
var model Model
if g.Models.UserDefined(typ.Name) {
continue
}
switch typ.Kind {
case ast.Object:
obj, err := g.buildObject(types, typ)
if err != nil {
return nil, err
}
if obj.Root || obj.IsUserDefined {
if obj.Root {
continue
}
model = g.obj2Model(obj)
Expand All @@ -27,15 +30,9 @@ func (g *Generator) buildModels(types NamedTypes, prog *loader.Program) ([]Model
if err != nil {
return nil, err
}
if obj.IsUserDefined {
continue
}
model = g.obj2Model(obj)
case ast.Interface, ast.Union:
intf := g.buildInterface(types, typ, prog)
if intf.IsUserDefined {
continue
}
model = int2Model(intf)
default:
continue
Expand All @@ -59,9 +56,6 @@ func (g *Generator) obj2Model(obj *Object) Model {
Fields: []ModelField{},
}

model.GoType = ucFirst(obj.GQLType)
model.Marshaler = &TypeImplementation{GoType: obj.GoType}

for i := range obj.Fields {
field := &obj.Fields[i]
mf := ModelField{TypeReference: field.TypeReference, GQLName: field.GQLName}
Expand All @@ -84,8 +78,5 @@ func int2Model(obj *Interface) Model {
Fields: []ModelField{},
}

model.GoType = ucFirst(obj.GQLType)
model.Marshaler = &TypeImplementation{GoType: obj.GoType}

return model
}
11 changes: 7 additions & 4 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
"text/template"
"unicode"

"go/types"

"github.com/99designs/gqlgen/codegen/templates"
"github.com/vektah/gqlparser/ast"
)

Expand All @@ -25,7 +28,7 @@ type Object struct {
Fields []Field
Satisfies []string
Implements []*TypeDefinition
ResolverInterface *TypeImplementation
ResolverInterface types.Type
Root bool
DisableConcurrency bool
Stream bool
Expand Down Expand Up @@ -169,7 +172,7 @@ func (f *Field) ShortResolverDeclaration() string {
res := fmt.Sprintf("%s(ctx context.Context", f.GoNameExported())

if !f.Object.Root {
res += fmt.Sprintf(", obj *%s", f.Object.FullName())
res += fmt.Sprintf(", obj *%s", templates.CurrentImports.LookupType(f.Object.GoType))
}
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
Expand All @@ -191,7 +194,7 @@ func (f *Field) ResolverDeclaration() string {
res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.GQLType, f.GoNameUnexported())

if !f.Object.Root {
res += fmt.Sprintf(", obj *%s", f.Object.FullName())
res += fmt.Sprintf(", obj *%s", templates.CurrentImports.LookupType(f.Object.GoType))
}
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
Expand Down Expand Up @@ -361,7 +364,7 @@ func (os Objects) ByName(name string) *Object {

func tpl(tpl string, vars map[string]interface{}) string {
b := &bytes.Buffer{}
err := template.Must(template.New("inline").Parse(tpl)).Execute(b, vars)
err := template.Must(template.New("inline").Funcs(templates.Funcs()).Parse(tpl)).Execute(b, vars)
if err != nil {
panic(err)
}
Expand Down
Loading

0 comments on commit 950ff42

Please sign in to comment.