From a2cce0d14984402fccd9e8abb9f286a80e7296fb Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Mon, 18 Mar 2019 13:46:19 +1100 Subject: [PATCH] Use graphql.String for types wrapping a basic string --- codegen/config/binder.go | 95 +++++++++++++++---------- codegen/testserver/generated.go | 95 +++++++++++++++++++++++++ codegen/testserver/gqlgen.yml | 2 + codegen/testserver/models.go | 8 +++ codegen/testserver/resolver.go | 3 + codegen/testserver/stub.go | 4 ++ codegen/testserver/typefallback.graphql | 9 +++ codegen/testserver/typefallback_test.go | 30 ++++++++ codegen/type.gotpl | 19 +++-- 9 files changed, 220 insertions(+), 45 deletions(-) create mode 100644 codegen/testserver/typefallback.graphql create mode 100644 codegen/testserver/typefallback_test.go diff --git a/codegen/config/binder.go b/codegen/config/binder.go index 98ceba8ebb..5e8c1cf696 100644 --- a/codegen/config/binder.go +++ b/codegen/config/binder.go @@ -158,9 +158,11 @@ func (b *Binder) PointerTo(ref *TypeReference) *TypeReference { newRef := &TypeReference{ GO: types.NewPointer(ref.GO), GQL: ref.GQL, + CastType: ref.CastType, Definition: ref.Definition, Unmarshaler: ref.Unmarshaler, Marshaler: ref.Marshaler, + IsMarshaler: ref.IsMarshaler, } b.References = append(b.References, newRef) @@ -172,8 +174,10 @@ type TypeReference struct { Definition *ast.Definition GQL *ast.Type GO types.Type + CastType types.Type // Before calling marshalling functions cast from/to this base type Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function + IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler } func (ref *TypeReference) Elem() *TypeReference { @@ -181,9 +185,11 @@ func (ref *TypeReference) Elem() *TypeReference { return &TypeReference{ GO: p.Elem(), GQL: ref.GQL, + CastType: ref.CastType, Definition: ref.Definition, Unmarshaler: ref.Unmarshaler, Marshaler: ref.Marshaler, + IsMarshaler: ref.IsMarshaler, } } @@ -191,9 +197,11 @@ func (ref *TypeReference) Elem() *TypeReference { return &TypeReference{ GO: s.Elem(), GQL: ref.GQL.Elem, + CastType: ref.CastType, Definition: ref.Definition, Unmarshaler: ref.Unmarshaler, Marshaler: ref.Marshaler, + IsMarshaler: ref.IsMarshaler, } } return nil @@ -249,44 +257,6 @@ func (t *TypeReference) HasIsZero() bool { return false } -func (t *TypeReference) SelfMarshalling() bool { - it := t.GO - if ptr, isPtr := it.(*types.Pointer); isPtr { - it = ptr.Elem() - } - namedType, ok := it.(*types.Named) - if !ok { - return false - } - - for i := 0; i < namedType.NumMethods(); i++ { - switch namedType.Method(i).Name() { - case "MarshalGQL": - return true - } - } - return false -} - -func (t *TypeReference) SelfUnmarshalling() bool { - it := t.GO - if ptr, isPtr := it.(*types.Pointer); isPtr { - it = ptr.Elem() - } - namedType, ok := it.(*types.Named) - if !ok { - return false - } - - for i := 0; i < namedType.NumMethods(); i++ { - switch namedType.Method(i).Name() { - case "UnmarshalGQL": - return true - } - } - return false -} - func (t *TypeReference) UniquenessKey() string { var nullability = "O" if t.GQL.NonNull { @@ -395,6 +365,22 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret ref.GO = fun.Type().(*types.Signature).Params().At(0).Type() ref.Marshaler = fun ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil) + } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") { + ref.GO = obj.Type() + ref.IsMarshaler = true + } else if underlying := basicUnderlying(obj.Type()); underlying != nil && underlying.Kind() == types.String { + // Special case for named types wrapping strings. Used by default enum implementations. + + ref.GO = obj.Type() + ref.CastType = underlying + + underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil) + if err != nil { + return nil, err + } + + ref.Marshaler = underlyingRef.Marshaler + ref.Unmarshaler = underlyingRef.Unmarshaler } else { ref.GO = obj.Type() } @@ -430,3 +416,36 @@ func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type { return base } + +func hasMethod(it types.Type, name string) bool { + if ptr, isPtr := it.(*types.Pointer); isPtr { + it = ptr.Elem() + } + namedType, ok := it.(*types.Named) + if !ok { + return false + } + + for i := 0; i < namedType.NumMethods(); i++ { + if namedType.Method(i).Name() == name { + return true + } + } + return false +} + +func basicUnderlying(it types.Type) *types.Basic { + if ptr, isPtr := it.(*types.Pointer); isPtr { + it = ptr.Elem() + } + namedType, ok := it.(*types.Named) + if !ok { + return nil + } + + if basic, ok := namedType.Underlying().(*types.Basic); ok { + return basic + } + + return nil +} diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index caa65855d7..2850867528 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -155,6 +155,7 @@ type ComplexityRoot struct { DirectiveInputType func(childComplexity int, arg InnerInput) int DirectiveNullableArg func(childComplexity int, arg *int, arg2 *int) int ErrorBubble func(childComplexity int) int + Fallback func(childComplexity int, arg FallbackToStringEncoding) int InputSlice func(childComplexity int, arg []string) int InvalidIdentifier func(childComplexity int) int MapInput func(childComplexity int, input map[string]interface{}) int @@ -265,6 +266,7 @@ type QueryResolver interface { Panics(ctx context.Context) (*Panics, error) DefaultScalar(ctx context.Context, arg string) (string, error) Slices(ctx context.Context) (*Slices, error) + Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) OptionalUnion(ctx context.Context) (TestUnion, error) ValidType(ctx context.Context) (*ValidType, error) } @@ -639,6 +641,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.ErrorBubble(childComplexity), true + case "Query.Fallback": + if e.complexity.Query.Fallback == nil { + break + } + + args, err := ec.field_Query_fallback_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.Fallback(childComplexity, args["arg"].(FallbackToStringEncoding)), true + case "Query.InputSlice": if e.complexity.Query.InputSlice == nil { break @@ -1286,6 +1300,16 @@ type Slices { test3: [String]! test4: [String!]! } +`}, + &ast.Source{Name: "typefallback.graphql", Input: `extend type Query { + fallback(arg: FallbackToStringEncoding!): FallbackToStringEncoding! +} + +enum FallbackToStringEncoding { + A + B + C +} `}, &ast.Source{Name: "useptr.graphql", Input: `type A { id: ID! @@ -1612,6 +1636,20 @@ func (ec *executionContext) field_Query_directiveNullableArg_args(ctx context.Co return args, nil } +func (ec *executionContext) field_Query_fallback_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 FallbackToStringEncoding + if tmp, ok := rawArgs["arg"]; ok { + arg0, err = ec.unmarshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx, tmp) + if err != nil { + return nil, err + } + } + args["arg"] = arg0 + return args, nil +} + func (ec *executionContext) field_Query_inputSlice_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -3680,6 +3718,40 @@ func (ec *executionContext) _Query_slices(ctx context.Context, field graphql.Col return ec.marshalOSlices2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐSlices(ctx, field.Selections, res) } +func (ec *executionContext) _Query_fallback(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + } + ctx = graphql.WithResolverContext(ctx, rctx) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Query_fallback_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + rctx.Args = args + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().Fallback(rctx, args["arg"].(FallbackToStringEncoding)) + }) + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(FallbackToStringEncoding) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx, field.Selections, res) +} + func (ec *executionContext) _Query_optionalUnion(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { ctx = ec.Tracer.StartFieldExecution(ctx, field) defer func() { ec.Tracer.EndFieldExecution(ctx) }() @@ -6474,6 +6546,20 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr res = ec._Query_slices(ctx, field) return res }) + case "fallback": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_fallback(ctx, field) + if res == graphql.Null { + invalid = true + } + return res + }) case "optionalUnion": field := field out.Concurrently(i, func() (res graphql.Marshaler) { @@ -7056,6 +7142,15 @@ func (ec *executionContext) marshalNDefaultScalarImplementation2string(ctx conte return graphql.MarshalString(v) } +func (ec *executionContext) unmarshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx context.Context, v interface{}) (FallbackToStringEncoding, error) { + tmp, err := graphql.UnmarshalString(v) + return FallbackToStringEncoding(tmp), err +} + +func (ec *executionContext) marshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx context.Context, sel ast.SelectionSet, v FallbackToStringEncoding) graphql.Marshaler { + return graphql.MarshalString(string(v)) +} + func (ec *executionContext) unmarshalNID2int(ctx context.Context, v interface{}) (int, error) { return graphql.UnmarshalIntID(v) } diff --git a/codegen/testserver/gqlgen.yml b/codegen/testserver/gqlgen.yml index 9979c5efc9..b3e4749654 100644 --- a/codegen/testserver/gqlgen.yml +++ b/codegen/testserver/gqlgen.yml @@ -66,3 +66,5 @@ models: oneFoo: { fieldName: foo } twoFoo: { fieldName: foo } oldFoo: { fieldName: foo, resolver: true } + FallbackToStringEncoding: + model: "github.com/99designs/gqlgen/codegen/testserver.FallbackToStringEncoding" diff --git a/codegen/testserver/models.go b/codegen/testserver/models.go index 20af894768..3ee710a50f 100644 --- a/codegen/testserver/models.go +++ b/codegen/testserver/models.go @@ -76,3 +76,11 @@ type OverlappingFields struct { Foo int NewFoo int } + +type FallbackToStringEncoding string + +const ( + FallbackToStringEncodingA FallbackToStringEncoding = "A" + FallbackToStringEncodingB FallbackToStringEncoding = "B" + FallbackToStringEncodingC FallbackToStringEncoding = "C" +) diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index e0fba141e8..b00e7ee0b6 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -140,6 +140,9 @@ func (r *queryResolver) DefaultScalar(ctx context.Context, arg string) (string, func (r *queryResolver) Slices(ctx context.Context) (*Slices, error) { panic("not implemented") } +func (r *queryResolver) Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { + panic("not implemented") +} func (r *queryResolver) OptionalUnion(ctx context.Context) (TestUnion, error) { panic("not implemented") } diff --git a/codegen/testserver/stub.go b/codegen/testserver/stub.go index 6706ead178..36a100b96f 100644 --- a/codegen/testserver/stub.go +++ b/codegen/testserver/stub.go @@ -50,6 +50,7 @@ type Stub struct { Panics func(ctx context.Context) (*Panics, error) DefaultScalar func(ctx context.Context, arg string) (string, error) Slices func(ctx context.Context) (*Slices, error) + Fallback func(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) OptionalUnion func(ctx context.Context) (TestUnion, error) ValidType func(ctx context.Context) (*ValidType, error) } @@ -191,6 +192,9 @@ func (r *stubQuery) DefaultScalar(ctx context.Context, arg string) (string, erro func (r *stubQuery) Slices(ctx context.Context) (*Slices, error) { return r.QueryResolver.Slices(ctx) } +func (r *stubQuery) Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { + return r.QueryResolver.Fallback(ctx, arg) +} func (r *stubQuery) OptionalUnion(ctx context.Context) (TestUnion, error) { return r.QueryResolver.OptionalUnion(ctx) } diff --git a/codegen/testserver/typefallback.graphql b/codegen/testserver/typefallback.graphql new file mode 100644 index 0000000000..e1ff1a59d7 --- /dev/null +++ b/codegen/testserver/typefallback.graphql @@ -0,0 +1,9 @@ +extend type Query { + fallback(arg: FallbackToStringEncoding!): FallbackToStringEncoding! +} + +enum FallbackToStringEncoding { + A + B + C +} diff --git a/codegen/testserver/typefallback_test.go b/codegen/testserver/typefallback_test.go new file mode 100644 index 0000000000..0b0f83135e --- /dev/null +++ b/codegen/testserver/typefallback_test.go @@ -0,0 +1,30 @@ +package testserver + +import ( + "context" + "net/http/httptest" + "testing" + + "github.com/99designs/gqlgen/client" + "github.com/99designs/gqlgen/handler" + "github.com/stretchr/testify/require" +) + +func TestTypeFallback(t *testing.T) { + resolvers := &Stub{} + + srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) + c := client.New(srv.URL) + + resolvers.QueryResolver.Fallback = func(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { + return arg, nil + } + + t.Run("fallback to string passthrough", func(t *testing.T) { + var resp struct { + Fallback string + } + c.MustPost(`query { fallback(arg: A) }`, &resp) + require.Equal(t, "A", resp.Fallback) + }) +} diff --git a/codegen/type.gotpl b/codegen/type.gotpl index 163c95ac40..f727baaca3 100644 --- a/codegen/type.gotpl +++ b/codegen/type.gotpl @@ -27,10 +27,15 @@ return res, nil {{- else }} {{- if $type.Unmarshaler }} - return {{ $type.Unmarshaler | call }}(v) + {{- if $type.CastType }} + tmp, err := {{ $type.Unmarshaler | call }}(v) + return {{ $type.GO | ref }}(tmp), err + {{- else}} + return {{ $type.Unmarshaler | call }}(v) + {{- end }} {{- else if eq ($type.GO | ref) "map[string]interface{}" }} return v.(map[string]interface{}), nil - {{- else if $type.SelfUnmarshalling -}} + {{- else if $type.IsMarshaler -}} var res {{ $type.GO | ref }} return res, res.UnmarshalGQL(v) {{- else }} @@ -62,9 +67,7 @@ } {{- end }} - {{- if $type.SelfMarshalling }} - return v - {{- else if $type.IsSlice }} + {{- if $type.IsSlice }} {{- if not $type.GQL.NonNull }} if v == nil { return graphql.Null @@ -111,11 +114,13 @@ return ret {{- else }} - {{- if $type.Marshaler }} + {{- if $type.IsMarshaler }} + return v + {{- else if $type.Marshaler }} {{- if $type.IsPtr }} return ec.{{ $type.Elem.MarshalFunc }}(ctx, sel, *v) {{- else }} - return {{ $type.Marshaler | call }}(v) + return {{ $type.Marshaler | call }}({{- if $type.CastType }}{{ $type.CastType | ref }}(v){{else}}v{{- end }}) {{- end }} {{- else }} return ec._{{$type.Definition.Name}}(ctx, sel, {{ if not $type.IsNilable}}&{{end}} v)