From 04cb2550483c3cf827b6668b69383e1f363e8dd4 Mon Sep 17 00:00:00 2001 From: Richard Musiol Date: Tue, 23 May 2017 13:40:02 +0200 Subject: [PATCH] allow structs without pointers Fixes #78. --- example/starwars/starwars.go | 26 ++++++++++++------------ graphql_test.go | 32 +++++++++++++++--------------- internal/exec/resolvable/packer.go | 16 ++++++++++++--- 3 files changed, 42 insertions(+), 32 deletions(-) diff --git a/example/starwars/starwars.go b/example/starwars/starwars.go index b482401a9a..55754c03ed 100644 --- a/example/starwars/starwars.go +++ b/example/starwars/starwars.go @@ -286,14 +286,14 @@ var reviews = make(map[string][]*review) type Resolver struct{} -func (r *Resolver) Hero(args *struct{ Episode string }) *characterResolver { +func (r *Resolver) Hero(args struct{ Episode string }) *characterResolver { if args.Episode == "EMPIRE" { return &characterResolver{&humanResolver{humanData["1000"]}} } return &characterResolver{&droidResolver{droidData["2001"]}} } -func (r *Resolver) Reviews(args *struct{ Episode string }) []*reviewResolver { +func (r *Resolver) Reviews(args struct{ Episode string }) []*reviewResolver { var l []*reviewResolver for _, review := range reviews[args.Episode] { l = append(l, &reviewResolver{review}) @@ -301,7 +301,7 @@ func (r *Resolver) Reviews(args *struct{ Episode string }) []*reviewResolver { return l } -func (r *Resolver) Search(args *struct{ Text string }) []*searchResultResolver { +func (r *Resolver) Search(args struct{ Text string }) []*searchResultResolver { var l []*searchResultResolver for _, h := range humans { if strings.Contains(h.Name, args.Text) { @@ -321,7 +321,7 @@ func (r *Resolver) Search(args *struct{ Text string }) []*searchResultResolver { return l } -func (r *Resolver) Character(args *struct{ ID graphql.ID }) *characterResolver { +func (r *Resolver) Character(args struct{ ID graphql.ID }) *characterResolver { if h := humanData[args.ID]; h != nil { return &characterResolver{&humanResolver{h}} } @@ -331,21 +331,21 @@ func (r *Resolver) Character(args *struct{ ID graphql.ID }) *characterResolver { return nil } -func (r *Resolver) Human(args *struct{ ID graphql.ID }) *humanResolver { +func (r *Resolver) Human(args struct{ ID graphql.ID }) *humanResolver { if h := humanData[args.ID]; h != nil { return &humanResolver{h} } return nil } -func (r *Resolver) Droid(args *struct{ ID graphql.ID }) *droidResolver { +func (r *Resolver) Droid(args struct{ ID graphql.ID }) *droidResolver { if d := droidData[args.ID]; d != nil { return &droidResolver{d} } return nil } -func (r *Resolver) Starship(args *struct{ ID graphql.ID }) *starshipResolver { +func (r *Resolver) Starship(args struct{ ID graphql.ID }) *starshipResolver { if s := starshipData[args.ID]; s != nil { return &starshipResolver{s} } @@ -373,7 +373,7 @@ type character interface { ID() graphql.ID Name() string Friends() *[]*characterResolver - FriendsConnection(*friendsConenctionArgs) (*friendsConnectionResolver, error) + FriendsConnection(friendsConenctionArgs) (*friendsConnectionResolver, error) AppearsIn() []string } @@ -403,7 +403,7 @@ func (r *humanResolver) Name() string { return r.h.Name } -func (r *humanResolver) Height(args *struct{ Unit string }) float64 { +func (r *humanResolver) Height(args struct{ Unit string }) float64 { return convertLength(r.h.Height, args.Unit) } @@ -419,7 +419,7 @@ func (r *humanResolver) Friends() *[]*characterResolver { return resolveCharacters(r.h.Friends) } -func (r *humanResolver) FriendsConnection(args *friendsConenctionArgs) (*friendsConnectionResolver, error) { +func (r *humanResolver) FriendsConnection(args friendsConenctionArgs) (*friendsConnectionResolver, error) { return newFriendsConnectionResolver(r.h.Friends, args) } @@ -451,7 +451,7 @@ func (r *droidResolver) Friends() *[]*characterResolver { return resolveCharacters(r.d.Friends) } -func (r *droidResolver) FriendsConnection(args *friendsConenctionArgs) (*friendsConnectionResolver, error) { +func (r *droidResolver) FriendsConnection(args friendsConenctionArgs) (*friendsConnectionResolver, error) { return newFriendsConnectionResolver(r.d.Friends, args) } @@ -478,7 +478,7 @@ func (r *starshipResolver) Name() string { return r.s.Name } -func (r *starshipResolver) Length(args *struct{ Unit string }) float64 { +func (r *starshipResolver) Length(args struct{ Unit string }) float64 { return convertLength(r.s.Length, args.Unit) } @@ -550,7 +550,7 @@ type friendsConnectionResolver struct { to int } -func newFriendsConnectionResolver(ids []graphql.ID, args *friendsConenctionArgs) (*friendsConnectionResolver, error) { +func newFriendsConnectionResolver(ids []graphql.ID, args friendsConenctionArgs) (*friendsConnectionResolver, error) { from := 0 if args.After != nil { b, err := base64.StdEncoding.DecodeString(string(*args.After)) diff --git a/graphql_test.go b/graphql_test.go index b52609682f..95d15b1dd4 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -28,7 +28,7 @@ func (r *helloSnakeResolver1) HelloHTML() string { return "Hello snake!" } -func (r *helloSnakeResolver1) SayHello(args *struct{ FullName string }) string { +func (r *helloSnakeResolver1) SayHello(args struct{ FullName string }) string { return "Hello " + args.FullName + "!" } @@ -38,7 +38,7 @@ func (r *helloSnakeResolver2) HelloHTML(ctx context.Context) (string, error) { return "Hello snake!", nil } -func (r *helloSnakeResolver2) SayHello(ctx context.Context, args *struct{ FullName string }) (string, error) { +func (r *helloSnakeResolver2) SayHello(ctx context.Context, args struct{ FullName string }) (string, error) { return "Hello " + args.FullName + "!", nil } @@ -50,14 +50,14 @@ func (r *theNumberResolver) TheNumber() int32 { return r.number } -func (r *theNumberResolver) ChangeTheNumber(args *struct{ NewNumber int32 }) *theNumberResolver { +func (r *theNumberResolver) ChangeTheNumber(args struct{ NewNumber int32 }) *theNumberResolver { r.number = args.NewNumber return r } type timeResolver struct{} -func (r *timeResolver) AddHour(args *struct{ Time graphql.Time }) graphql.Time { +func (r *timeResolver) AddHour(args struct{ Time graphql.Time }) graphql.Time { return graphql.Time{Time: args.Time.Add(time.Hour)} } @@ -1507,7 +1507,7 @@ func TestTime(t *testing.T) { type resolverWithUnexportedMethod struct{} -func (r *resolverWithUnexportedMethod) changeTheNumber(args *struct{ NewNumber int32 }) int32 { +func (r *resolverWithUnexportedMethod) changeTheNumber(args struct{ NewNumber int32 }) int32 { return args.NewNumber } @@ -1528,7 +1528,7 @@ func TestUnexportedMethod(t *testing.T) { type resolverWithUnexportedField struct{} -func (r *resolverWithUnexportedField) ChangeTheNumber(args *struct{ newNumber int32 }) int32 { +func (r *resolverWithUnexportedField) ChangeTheNumber(args struct{ newNumber int32 }) int32 { return args.newNumber } @@ -1549,27 +1549,27 @@ func TestUnexportedField(t *testing.T) { type inputResolver struct{} -func (r *inputResolver) Int(args *struct{ Value int32 }) int32 { +func (r *inputResolver) Int(args struct{ Value int32 }) int32 { return args.Value } -func (r *inputResolver) Float(args *struct{ Value float64 }) float64 { +func (r *inputResolver) Float(args struct{ Value float64 }) float64 { return args.Value } -func (r *inputResolver) String(args *struct{ Value string }) string { +func (r *inputResolver) String(args struct{ Value string }) string { return args.Value } -func (r *inputResolver) Boolean(args *struct{ Value bool }) bool { +func (r *inputResolver) Boolean(args struct{ Value bool }) bool { return args.Value } -func (r *inputResolver) Nullable(args *struct{ Value *int32 }) *int32 { +func (r *inputResolver) Nullable(args struct{ Value *int32 }) *int32 { return args.Value } -func (r *inputResolver) List(args *struct{ Value []*struct{ V int32 } }) []int32 { +func (r *inputResolver) List(args struct{ Value []*struct{ V int32 } }) []int32 { l := make([]int32, len(args.Value)) for i, entry := range args.Value { l[i] = entry.V @@ -1577,7 +1577,7 @@ func (r *inputResolver) List(args *struct{ Value []*struct{ V int32 } }) []int32 return l } -func (r *inputResolver) NullableList(args *struct{ Value *[]*struct{ V int32 } }) *[]*int32 { +func (r *inputResolver) NullableList(args struct{ Value *[]*struct{ V int32 } }) *[]*int32 { if args.Value == nil { return nil } @@ -1590,11 +1590,11 @@ func (r *inputResolver) NullableList(args *struct{ Value *[]*struct{ V int32 } } return &l } -func (r *inputResolver) Enum(args *struct{ Value string }) string { +func (r *inputResolver) Enum(args struct{ Value string }) string { return args.Value } -func (r *inputResolver) NullableEnum(args *struct{ Value *string }) *string { +func (r *inputResolver) NullableEnum(args struct{ Value *string }) *string { return args.Value } @@ -1602,7 +1602,7 @@ type recursive struct { Next *recursive } -func (r *inputResolver) Recursive(args *struct{ Value *recursive }) int32 { +func (r *inputResolver) Recursive(args struct{ Value *recursive }) int32 { n := int32(0) v := args.Value for v != nil { diff --git a/internal/exec/resolvable/packer.go b/internal/exec/resolvable/packer.go index c21d672b08..115f40b0bc 100644 --- a/internal/exec/resolvable/packer.go +++ b/internal/exec/resolvable/packer.go @@ -122,10 +122,15 @@ func (b *execBuilder) makeNonNullPacker(schemaType common.Type, reflectType refl } func (b *execBuilder) makeStructPacker(values common.InputValueList, typ reflect.Type) (*StructPacker, error) { - if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Struct { - return nil, fmt.Errorf("expected pointer to struct, got %s", typ) + structType := typ + usePtr := false + if typ.Kind() == reflect.Ptr { + structType = typ.Elem() + usePtr = true + } + if structType.Kind() != reflect.Struct { + return nil, fmt.Errorf("expected struct or pointer to struct, got %s", typ) } - structType := typ.Elem() var fields []*structPackerField for _, v := range values { @@ -158,6 +163,7 @@ func (b *execBuilder) makeStructPacker(values common.InputValueList, typ reflect p := &StructPacker{ structType: structType, + usePtr: usePtr, fields: fields, } b.structPackers = append(b.structPackers, p) @@ -166,6 +172,7 @@ func (b *execBuilder) makeStructPacker(values common.InputValueList, typ reflect type StructPacker struct { structType reflect.Type + usePtr bool defaultStruct reflect.Value fields []*structPackerField } @@ -193,6 +200,9 @@ func (p *StructPacker) Pack(r *Request, value interface{}) (reflect.Value, error v.Elem().FieldByIndex(f.fieldIndex).Set(packed) } } + if !p.usePtr { + return v.Elem(), nil + } return v, nil }