Skip to content

Commit

Permalink
Use ast definition directly, instead of copying
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Jan 8, 2019
1 parent 8298acb commit afc773b
Show file tree
Hide file tree
Showing 21 changed files with 90 additions and 102 deletions.
2 changes: 1 addition & 1 deletion codegen/directive_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (g *Generator) buildDirectives(types NamedTypes) (map[string]*Directive, er
GoVarName: sanitizeArgName(arg.Name),
}

if !newArg.TypeReference.Definition.IsInput && !newArg.TypeReference.Definition.IsScalar {
if !newArg.TypeReference.Definition.GQLDefinition.IsInputType() {
return nil, errors.Errorf("%s cannot be used as argument of directive %s(%s) only input and scalar types are allowed", arg.Type, dir.Name, arg.Name)
}

Expand Down
4 changes: 2 additions & 2 deletions codegen/enum_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ func (g *Generator) buildEnums(ts NamedTypes) []Enum {
Description: typ.Description,
}

enum.Definition.GoType = types.NewNamed(types.NewTypeName(0, g.Config.Model.Pkg(), templates.ToCamel(enum.Definition.GQLType), nil), nil, nil)
enum.Definition.GoType = types.NewNamed(types.NewTypeName(0, g.Config.Model.Pkg(), templates.ToCamel(enum.Definition.GQLDefinition.Name), nil), nil, nil)

enums = append(enums, enum)
}

sort.Slice(enums, func(i, j int) bool {
return enums[i].Definition.GQLType < enums[j].Definition.GQLType
return enums[i].Definition.GQLDefinition.Name < enums[j].Definition.GQLDefinition.Name
})

return enums
Expand Down
8 changes: 4 additions & 4 deletions codegen/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ func (g *Generator) Generate() error {
}

for _, model := range modelsBuild.Models {
modelCfg := g.Models[model.Definition.GQLType]
modelCfg := g.Models[model.Definition.GQLDefinition.Name]
modelCfg.Model = types.TypeString(model.Definition.GoType, nil)
g.Models[model.Definition.GQLType] = modelCfg
g.Models[model.Definition.GQLDefinition.Name] = modelCfg
}

for _, enum := range modelsBuild.Enums {
modelCfg := g.Models[enum.Definition.GQLType]
modelCfg := g.Models[enum.Definition.GQLDefinition.Name]
modelCfg.Model = types.TypeString(enum.Definition.GoType, nil)
g.Models[enum.Definition.GQLType] = modelCfg
g.Models[enum.Definition.GQLDefinition.Name] = modelCfg
}
}

Expand Down
10 changes: 7 additions & 3 deletions codegen/input_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (g *Generator) buildInputs(namedTypes NamedTypes, prog *loader.Program) (Ob
}

sort.Slice(inputs, func(i, j int) bool {
return inputs[i].Definition.GQLType < inputs[j].Definition.GQLType
return inputs[i].Definition.GQLDefinition.Name < inputs[j].Definition.GQLDefinition.Name
})

return inputs, nil
Expand Down Expand Up @@ -69,8 +69,12 @@ func (g *Generator) buildInput(types NamedTypes, typ *ast.Definition) (*Object,
}
}

if !newField.TypeReference.Definition.IsInput && !newField.TypeReference.Definition.IsScalar {
return nil, errors.Errorf("%s cannot be used as a field of %s. only input and scalar types are allowed", newField.Definition.GQLType, obj.Definition.GQLType)
if !newField.TypeReference.Definition.GQLDefinition.IsInputType() {
return nil, errors.Errorf(
"%s cannot be used as a field of %s. only input and scalar types are allowed",
newField.Definition.GQLDefinition.Name,
obj.Definition.GQLDefinition.Name,
)
}

obj.Fields = append(obj.Fields, newField)
Expand Down
2 changes: 1 addition & 1 deletion codegen/interface_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func (g *Generator) buildInterfaces(types NamedTypes, prog *loader.Program) []*I
}

sort.Slice(interfaces, func(i, j int) bool {
return interfaces[i].Definition.GQLType < interfaces[j].Definition.GQLType
return interfaces[i].Definition.GQLDefinition.Name < interfaces[j].Definition.GQLDefinition.Name
})

return interfaces
Expand Down
2 changes: 1 addition & 1 deletion codegen/models_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (g *Generator) buildModels(types NamedTypes, prog *loader.Program) ([]Model
}

sort.Slice(models, func(i, j int) bool {
return models[i].Definition.GQLType < models[j].Definition.GQLType
return models[i].Definition.GQLDefinition.Name < models[j].Definition.GQLDefinition.Name
})

return models, nil
Expand Down
20 changes: 10 additions & 10 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ type FieldArgument struct {
type Objects []*Object

func (o *Object) Implementors() string {
satisfiedBy := strconv.Quote(o.Definition.GQLType)
satisfiedBy := strconv.Quote(o.Definition.GQLDefinition.Name)
for _, s := range o.Satisfies {
satisfiedBy += ", " + strconv.Quote(s)
}
Expand Down Expand Up @@ -102,7 +102,7 @@ func (o *Object) IsConcurrent() bool {
}

func (o *Object) IsReserved() bool {
return strings.HasPrefix(o.Definition.GQLType, "__")
return strings.HasPrefix(o.Definition.GQLDefinition.Name, "__")
}

func (f *Field) HasDirectives() bool {
Expand Down Expand Up @@ -145,23 +145,23 @@ func (f *Field) ShortInvocation() string {
return ""
}

return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.GQLType, f.GoNameExported(), f.CallArgs())
return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.GQLDefinition.Name, f.GoNameExported(), f.CallArgs())
}

func (f *Field) ArgsFunc() string {
if len(f.Args) == 0 {
return ""
}

return "field_" + f.Object.Definition.GQLType + "_" + f.GQLName + "_args"
return "field_" + f.Object.Definition.GQLDefinition.Name + "_" + f.GQLName + "_args"
}

func (f *Field) ResolverType() string {
if !f.IsResolver() {
return ""
}

return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.GQLType, f.GoNameExported(), f.CallArgs())
return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.GQLDefinition.Name, f.GoNameExported(), f.CallArgs())
}

func (f *Field) ShortResolverDeclaration() string {
Expand Down Expand Up @@ -190,7 +190,7 @@ func (f *Field) ResolverDeclaration() string {
if !f.IsResolver() {
return ""
}
res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.Definition.GQLType, f.GoNameUnexported())
res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.Definition.GQLDefinition.Name, f.GoNameUnexported())

if !f.Object.Root {
res += fmt.Sprintf(", obj *%s", templates.CurrentImports.LookupType(f.Object.Definition.GoType))
Expand Down Expand Up @@ -329,13 +329,13 @@ func (f *Field) doWriteJson(val string, destType types.Type, astType *ast.Type,
"index": index,
"top": depth == 1,
"arrayLen": len(val),
"isScalar": f.Definition.IsScalar,
"isScalar": f.Definition.GQLDefinition.Kind == ast.Scalar || f.Definition.GQLDefinition.Kind == ast.Enum,
"usePtr": usePtr,
"next": f.doWriteJson(val+"["+index+"]", destType.Elem(), astType.Elem, false, depth+1),
})

default:
if f.Definition.IsScalar {
if f.Definition.GQLDefinition.Kind == ast.Scalar || f.Definition.GQLDefinition.Kind == ast.Enum {
if isPtr {
val = "*" + val
}
Expand All @@ -347,7 +347,7 @@ func (f *Field) doWriteJson(val string, destType types.Type, astType *ast.Type,
}
return tpl(`
return ec._{{.type}}(ctx, field.Selections, {{.val}})`, map[string]interface{}{
"type": f.Definition.GQLType,
"type": f.Definition.GQLDefinition.Name,
"val": val,
})
}
Expand All @@ -359,7 +359,7 @@ func (f *FieldArgument) Stream() bool {

func (os Objects) ByName(name string) *Object {
for i, o := range os {
if strings.EqualFold(o.Definition.GQLType, name) {
if strings.EqualFold(o.Definition.GQLDefinition.Name, name) {
return os[i]
}
}
Expand Down
8 changes: 4 additions & 4 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (g *Generator) buildObjects(ts NamedTypes, prog *loader.Program) (Objects,
}

sort.Slice(objects, func(i, j int) bool {
return objects[i].Definition.GQLType < objects[j].Definition.GQLType
return objects[i].Definition.GQLDefinition.Name < objects[j].Definition.GQLDefinition.Name
})

return objects, nil
Expand Down Expand Up @@ -84,7 +84,7 @@ func (g *Generator) buildObject(prog *loader.Program, ts NamedTypes, typ *ast.De
obj := &Object{Definition: ts[typ.Name]}
typeEntry, entryExists := g.Models[typ.Name]

tt := types.NewTypeName(0, g.Config.Exec.Pkg(), obj.Definition.GQLType+"Resolver", nil)
tt := types.NewTypeName(0, g.Config.Exec.Pkg(), obj.Definition.GQLDefinition.Name+"Resolver", nil)
obj.ResolverInterface = types.NewNamed(tt, nil, nil)

if typ == g.schema.Query {
Expand Down Expand Up @@ -168,8 +168,8 @@ func (g *Generator) buildObject(prog *loader.Program, ts NamedTypes, typ *ast.De
Directives: dirs,
}

if !newArg.TypeReference.Definition.IsInput && !newArg.TypeReference.Definition.IsScalar {
return nil, errors.Errorf("%s cannot be used as argument of %s.%s. only input and scalar types are allowed", arg.Type, obj.Definition.GQLType, field.Name)
if !newArg.TypeReference.Definition.GQLDefinition.IsInputType() {
return nil, errors.Errorf("%s cannot be used as argument of %s.%s. only input and scalar types are allowed", arg.Type, obj.Definition.GQLDefinition.Name, field.Name)
}

if arg.DefaultValue != nil {
Expand Down
2 changes: 1 addition & 1 deletion codegen/templates/args.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
return nil, err
}
{{- end }}
{{- if $arg.Definition.IsInput }}
{{- if eq $arg.Definition.GQLDefinition.Kind "INPUT_OBJECT" }}
{{ $arg.Middleware (print "arg" $i) (print "arg" $i) }}
{{- end }}
}
Expand Down

0 comments on commit afc773b

Please sign in to comment.