Skip to content

Commit

Permalink
Shared arg unmarshaling logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam Scarr committed Feb 5, 2019
1 parent 555d746 commit 6047355
Show file tree
Hide file tree
Showing 28 changed files with 1,006 additions and 669 deletions.
53 changes: 37 additions & 16 deletions codegen/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"go/types"
"strings"

"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/pkg/errors"
"github.com/vektah/gqlparser/ast"
Expand All @@ -15,26 +16,46 @@ type ArgSet struct {
FuncDecl string
}

type FieldArgument struct {
*ast.ArgumentDefinition
TypeReference *config.TypeReference
VarName string // The name of the var in go
Object *Object // A link back to the parent object
Default interface{} // The default value
Directives []*Directive
Value interface{} // value set in Data
}

func (f *FieldArgument) Stream() bool {
return f.Object != nil && f.Object.Stream
}

func (b *builder) buildArg(obj *Object, arg *ast.ArgumentDefinition) (*FieldArgument, error) {
def := b.Schema.Types[arg.Type.Name()]
if !def.IsInputType() {
return nil, errors.Errorf(
"cannot use %s as argument %s because %s is not a valid input type",
arg.Type.String(),
arg.Name,
def.Kind,
)
}

tr, err := b.Binder.TypeReference(arg.Type)
if err != nil {
return nil, err
}

argDirs, err := b.getDirectives(arg.Directives)
if err != nil {
return nil, err
}
newArg := FieldArgument{
GQLName: arg.Name,
TypeReference: b.NamedTypes.getType(arg.Type),
Object: obj,
GoVarName: templates.ToGoPrivate(arg.Name),
Directives: argDirs,
}

if !newArg.TypeReference.Definition.GQLDefinition.IsInputType() {
return nil, errors.Errorf(
"cannot use %s as argument %s because %s is not a valid input type",
newArg.Definition.GQLDefinition.Name,
arg.Name,
newArg.TypeReference.Definition.GQLDefinition.Kind,
)
ArgumentDefinition: arg,
TypeReference: tr,
Object: obj,
VarName: templates.ToGoPrivate(arg.Name),
Directives: argDirs,
}

if arg.DefaultValue != nil {
Expand All @@ -54,8 +75,8 @@ nextArg:
for j := 0; j < params.Len(); j++ {
param := params.At(j)
for _, oldArg := range field.Args {
if strings.EqualFold(oldArg.GQLName, param.Name()) {
oldArg.TypeReference.GoType = param.Type()
if strings.EqualFold(oldArg.Name, param.Name()) {
oldArg.TypeReference.GO = param.Type()
newArgs = append(newArgs, oldArg)
continue nextArg
}
Expand Down
49 changes: 23 additions & 26 deletions codegen/args.gotpl
Original file line number Diff line number Diff line change
@@ -1,49 +1,46 @@
{{ range $name, $args := .Args }}
func (e *executableSchema){{ $name }}(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
{{- range $i, $arg := . }}
var arg{{$i}} {{$arg.GoType | ref }}
if tmp, ok := rawArgs[{{$arg.GQLName|quote}}]; ok {
var arg{{$i}} {{ $arg.TypeReference.GO | ref}}
if tmp, ok := rawArgs[{{$arg.Name|quote}}]; ok {
{{- if $arg.Directives }}
argm{{$i}}, err := chainFieldMiddleware([]graphql.FieldMiddleware{
{{- range $directive := $arg.Directives }}
func(ctx context.Context, n graphql.Resolver) (res interface{}, err error) {
{{- range $dArg := $directive.Args }}
{{- if and $dArg.IsPtr ( notNil "Value" $dArg ) }}
{{ $dArg.GoVarName }} := {{ $dArg.Value | dump }}
getArg0 := func(ctx context.Context) (interface{}, error) { return unmarshal{{$arg.TypeReference.GQL.Name}}2{{ $arg.TypeReference.GO | ts }}(tmp) }

{{- range $i, $directive := $arg.Directives }}
getArg{{add $i 1}} := func(ctx context.Context) (res interface{}, err error) {
{{- range $dArg := $directive.Args }}
{{- if and $dArg.TypeReference.IsPtr ( notNil "Value" $dArg ) }}
{{ $dArg.VarName }} := {{ $dArg.Value | dump }}
{{- end }}
{{- end }}
{{- end }}
n := getArg{{$i}}
return e.directives.{{$directive.Name|ucFirst}}({{$directive.ResolveArgs "tmp" "n" }})
},
{{- end }}
}...)(ctx, func(ctx2 context.Context) (interface{}, error) {
var err error
{{$arg.Unmarshal (print "arg" $i) "tmp" }}
if err != nil {
return nil, err
}
return arg{{ $i }}, nil
})
{{- end }}

tmp, err = getArg{{$arg.Directives|len}}(ctx)
if err != nil {
return nil, err
}
if data, ok := argm{{$i}}.({{$arg.GoType | ref }}); ok {
if data, ok := tmp.({{ $arg.TypeReference.GO }}) ; ok {
arg{{$i}} = data
} else {
return nil, errors.New("expect {{$arg.GoType | ref }}")
return nil, fmt.Errorf(`unexpected type %T from directive, should be {{ $arg.TypeReference.GO }}`, tmp)
}
{{- else }}
var err error
{{ $arg.Unmarshal (print "arg" $i) "tmp" }}
arg{{$i}}, err = unmarshal{{$arg.TypeReference.GQL.Name}}2{{ $arg.TypeReference.GO | ts }}(tmp)
if err != nil {
return nil, err
}
{{- end }}
{{- if eq $arg.Definition.GQLDefinition.Kind "INPUT_OBJECT" }}
{{ $arg.Middleware (print "arg" $i) (print "arg" $i) }}
{{- end }}

{{/*{{- if eq $arg.TypeReference.Definition.Kind "INPUT_OBJECT" }}*/}}
{{/*{{ $arg.Middleware (print "arg" $i) (print "arg" $i) }}*/}}
{{/*{{- end }}*/}}
}
args[{{$arg.GQLName|quote}}] = arg{{$i}}
args[{{$arg.Name|quote}}] = arg{{$i}}
{{- end }}
return args, nil
}
Expand Down
2 changes: 1 addition & 1 deletion codegen/build_typedef.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (b *builder) buildTypeDef(schemaType *ast.Definition) (*TypeDefinition, err

// Special case to reference generated unmarshal functions
if !hasUnmarshal {
t.Unmarshaler = types.NewFunc(0, b.Config.Exec.Pkg(), "Unmarshal"+schemaType.Name, nil)
t.Unmarshaler = types.NewFunc(0, b.Config.Exec.Pkg(), "unmarshalInput"+schemaType.Name, nil)
}

return t, nil
Expand Down
100 changes: 87 additions & 13 deletions codegen/config/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,22 @@ import (

// Binder connects graphql types to golang types using static analysis
type Binder struct {
pkgs []*packages.Package
types TypeMap
pkgs []*packages.Package
schema *ast.Schema
cfg *Config
References []*TypeReference
}

func (c *Config) NewBinder() (*Binder, error) {
func (c *Config) NewBinder(s *ast.Schema) (*Binder, error) {
pkgs, err := packages.Load(&packages.Config{Mode: packages.LoadTypes | packages.LoadSyntax}, c.Models.ReferencedPackages()...)
if err != nil {
return nil, err
}

return &Binder{
pkgs: pkgs,
types: c.Models,
pkgs: pkgs,
schema: s,
cfg: c,
}, nil
}

Expand Down Expand Up @@ -55,7 +58,7 @@ var MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil,
var InterfaceType = types.NewInterfaceType(nil, nil)

func (b *Binder) FindUserObject(name string) (types.Type, error) {
userEntry, ok := b.types[name]
userEntry, ok := b.cfg.Models[name]
if !ok {
return nil, fmt.Errorf(name + " not found")
}
Expand Down Expand Up @@ -118,17 +121,58 @@ func normalizeVendor(pkg string) string {
return modifiers + parts[len(parts)-1]
}

func (b *Binder) FindBackingType(schemaType *ast.Type) (types.Type, error) {
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
type TypeReference struct {
Definition *ast.Definition
GQL *ast.Type
GO types.Type
Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function
Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function
}

func (t TypeReference) IsPtr() bool {
_, isPtr := t.GO.(*types.Pointer)
return isPtr
}

func (t TypeReference) IsSlice() bool {
_, isSlice := t.GO.(*types.Slice)
return isSlice
}

func (t TypeReference) IsNamed() bool {
_, isSlice := t.GO.(*types.Named)
return isSlice
}

func (b *Binder) PushRef(ret *TypeReference) {
b.References = append(b.References, ret)
}

func (b *Binder) TypeReference(schemaType *ast.Type) (ret *TypeReference, err error) {
var pkgName, typeName string
def := b.schema.Types[schemaType.Name()]
defer func() {
if err == nil && ret != nil {
b.PushRef(ret)
}
}()

if userEntry, ok := b.types[schemaType.Name()]; ok && userEntry.Model != "" {
// special case for maps
if userEntry, ok := b.cfg.Models[schemaType.Name()]; ok && userEntry.Model != "" {
if userEntry.Model == "map[string]interface{}" {
return MapType, nil
return &TypeReference{
Definition: def,
GQL: schemaType,
GO: MapType,
}, nil
}

if userEntry.Model == "interface{}" {
return InterfaceType, nil
return &TypeReference{
Definition: def,
GQL: schemaType,
GO: InterfaceType,
}, nil
}

pkgName, typeName = code.PkgAndType(userEntry.Model)
Expand All @@ -141,12 +185,42 @@ func (b *Binder) FindBackingType(schemaType *ast.Type) (types.Type, error) {
typeName = "String"
}

t, err := b.FindType(pkgName, typeName)
ref := &TypeReference{
Definition: def,
GQL: schemaType,
}

obj, err := b.FindObject(pkgName, typeName)
if err != nil {
return nil, err
}

return b.CopyModifiersFromAst(schemaType, true, t), nil
if fun, isFunc := obj.(*types.Func); isFunc {
ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
ref.Marshaler = fun
ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
} else {
ref.GO = obj.Type()
}

if namedType, ok := ref.GO.(*types.Named); ok && ref.Unmarshaler == nil {
hasUnmarshal := false
for i := 0; i < namedType.NumMethods(); i++ {
switch namedType.Method(i).Name() {
case "UnmarshalGQL":
hasUnmarshal = true
}
}

// Special case to reference generated unmarshal functions
if !hasUnmarshal {
ref.Unmarshaler = types.NewFunc(0, b.cfg.Exec.Pkg(), "unmarshalInput"+schemaType.Name(), nil)
}
}

ref.GO = b.CopyModifiersFromAst(schemaType, true, ref.GO)

return ref, nil
}

func (b *Binder) CopyModifiersFromAst(t *ast.Type, usePtr bool, base types.Type) types.Type {
Expand Down
36 changes: 22 additions & 14 deletions codegen/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ import (
// Data is a unified model of the code to be generated. Plugins may modify this structure to do things like implement
// resolvers or directives automatically (eg grpc, validation)
type Data struct {
Config *config.Config
Schema *ast.Schema
SchemaStr map[string]string
Directives map[string]*Directive
Objects Objects
Inputs Objects
Interfaces map[string]*Interface
Config *config.Config
Schema *ast.Schema
SchemaStr map[string]string
Directives map[string]*Directive
Objects Objects
Inputs Objects
Interfaces map[string]*Interface
ReferencedTypes map[string]*config.TypeReference

QueryRoot *Object
MutationRoot *Object
Expand Down Expand Up @@ -53,7 +54,7 @@ func BuildData(cfg *config.Config) (*Data, error) {

cfg.InjectBuiltins(b.Schema)

b.Binder, err = b.Config.NewBinder()
b.Binder, err = b.Config.NewBinder(b.Schema)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -120,6 +121,11 @@ func BuildData(cfg *config.Config) (*Data, error) {
return nil, err
}

s.ReferencedTypes, err = b.buildTypes()
if err != nil {
return nil, err
}

sort.Slice(s.Objects, func(i, j int) bool {
return s.Objects[i].Definition.Name < s.Objects[j].Definition.Name
})
Expand All @@ -141,6 +147,10 @@ func (b *builder) injectIntrospectionRoots(s *Data) error {
if err != nil {
return errors.Wrap(err, "unable to find root Type introspection type")
}
stringRef, err := b.Binder.TypeReference(ast.NonNullNamedType("String", nil))
if err != nil {
return errors.Wrap(err, "unable to find root string type reference")
}

obj.Fields = append(obj.Fields, &Field{
TypeReference: &TypeReference{b.NamedTypes["__Type"], types.NewPointer(typeType.Type()), ast.NamedType("__Schema", nil)},
Expand All @@ -150,13 +160,11 @@ func (b *builder) injectIntrospectionRoots(s *Data) error {
GoFieldName: "introspectType",
Args: []*FieldArgument{
{
GQLName: "name",
TypeReference: &TypeReference{
b.NamedTypes["String"],
types.Typ[types.String],
ast.NamedType("String", nil),
ArgumentDefinition: &ast.ArgumentDefinition{
Name: "name",
},
Object: &Object{},
TypeReference: stringRef,
Object: &Object{},
},
},
Object: obj,
Expand Down

0 comments on commit 6047355

Please sign in to comment.