From f02dabb7bad39e049323e05b171feb8f387d257a Mon Sep 17 00:00:00 2001 From: Mathew Byrne Date: Mon, 18 Mar 2019 14:39:31 +1100 Subject: [PATCH] Add test case for union pointer --- codegen/testserver/generated.go | 209 ++++++++++++++++++++++++++++++ codegen/testserver/models-gen.go | 16 +++ codegen/testserver/resolver.go | 3 + codegen/testserver/stub.go | 4 + codegen/testserver/useptr.graphql | 13 ++ codegen/testserver/useptr_test.go | 14 ++ 6 files changed, 259 insertions(+) create mode 100644 codegen/testserver/useptr.graphql create mode 100644 codegen/testserver/useptr_test.go diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index 99de6ccc5e..caa65855d7 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -56,6 +56,10 @@ type DirectiveRoot struct { } type ComplexityRoot struct { + A struct { + ID func(childComplexity int) int + } + AIt struct { ID func(childComplexity int) int } @@ -72,6 +76,10 @@ type ComplexityRoot struct { Int64 func(childComplexity int) int } + B struct { + ID func(childComplexity int) int + } + Circle struct { Area func(childComplexity int) int Radius func(childComplexity int) int @@ -155,6 +163,7 @@ type ComplexityRoot struct { NestedInputs func(childComplexity int, input [][]*OuterInput) int NestedOutputs func(childComplexity int) int NullableArg func(childComplexity int, arg *int) int + OptionalUnion func(childComplexity int) int Overlapping func(childComplexity int) int Panics func(childComplexity int) int Recursive func(childComplexity int, input *RecursiveInputSlice) int @@ -256,6 +265,7 @@ type QueryResolver interface { Panics(ctx context.Context) (*Panics, error) DefaultScalar(ctx context.Context, arg string) (string, error) Slices(ctx context.Context) (*Slices, error) + OptionalUnion(ctx context.Context) (TestUnion, error) ValidType(ctx context.Context) (*ValidType, error) } type SubscriptionResolver interface { @@ -281,6 +291,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in _ = ec switch typeName + "." + field { + case "A.ID": + if e.complexity.A.ID == nil { + break + } + + return e.complexity.A.ID(childComplexity), true + case "AIt.ID": if e.complexity.AIt.ID == nil { break @@ -330,6 +347,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Autobind.Int64(childComplexity), true + case "B.ID": + if e.complexity.B.ID == nil { + break + } + + return e.complexity.B.ID(childComplexity), true + case "Circle.Area": if e.complexity.Circle.Area == nil { break @@ -696,6 +720,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.NullableArg(childComplexity, args["arg"].(*int)), true + case "Query.OptionalUnion": + if e.complexity.Query.OptionalUnion == nil { + break + } + + return e.complexity.Query.OptionalUnion(childComplexity), true + case "Query.Overlapping": if e.complexity.Query.Overlapping == nil { break @@ -1255,6 +1286,20 @@ type Slices { test3: [String]! test4: [String!]! } +`}, + &ast.Source{Name: "useptr.graphql", Input: `type A { + id: ID! +} + +type B { + id: ID! +} + +union TestUnion = A | B + +extend type Query { + optionalUnion: TestUnion +} `}, &ast.Source{Name: "validtypes.graphql", Input: `extend type Query { validType: ValidType @@ -1925,6 +1970,33 @@ func (ec *executionContext) field___Type_fields_args(ctx context.Context, rawArg // region **************************** field.gotpl ***************************** +func (ec *executionContext) _A_id(ctx context.Context, field graphql.CollectedField, obj *A) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "A", + Field: field, + Args: nil, + IsMethod: false, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.ID, nil + }) + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalNID2string(ctx, field.Selections, res) +} + func (ec *executionContext) _AIt_id(ctx context.Context, field graphql.CollectedField, obj *AIt) graphql.Marshaler { ctx = ec.Tracer.StartFieldExecution(ctx, field) defer func() { ec.Tracer.EndFieldExecution(ctx) }() @@ -2114,6 +2186,33 @@ func (ec *executionContext) _Autobind_idInt(ctx context.Context, field graphql.C return ec.marshalNID2int(ctx, field.Selections, res) } +func (ec *executionContext) _B_id(ctx context.Context, field graphql.CollectedField, obj *B) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "B", + Field: field, + Args: nil, + IsMethod: false, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.ID, nil + }) + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalNID2string(ctx, field.Selections, res) +} + func (ec *executionContext) _Circle_radius(ctx context.Context, field graphql.CollectedField, obj *Circle) graphql.Marshaler { ctx = ec.Tracer.StartFieldExecution(ctx, field) defer func() { ec.Tracer.EndFieldExecution(ctx) }() @@ -3581,6 +3680,30 @@ func (ec *executionContext) _Query_slices(ctx context.Context, field graphql.Col return ec.marshalOSlices2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐSlices(ctx, field.Selections, res) } +func (ec *executionContext) _Query_optionalUnion(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().OptionalUnion(rctx) + }) + if resTmp == nil { + return graphql.Null + } + res := resTmp.(TestUnion) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalOTestUnion2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐTestUnion(ctx, field.Selections, res) +} + func (ec *executionContext) _Query_validType(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { ctx = ec.Tracer.StartFieldExecution(ctx, field) defer func() { ec.Tracer.EndFieldExecution(ctx) }() @@ -5399,10 +5522,54 @@ func (ec *executionContext) _ShapeUnion(ctx context.Context, sel ast.SelectionSe } } +func (ec *executionContext) _TestUnion(ctx context.Context, sel ast.SelectionSet, obj *TestUnion) graphql.Marshaler { + switch obj := (*obj).(type) { + case nil: + return graphql.Null + case A: + return ec._A(ctx, sel, &obj) + case *A: + return ec._A(ctx, sel, obj) + case B: + return ec._B(ctx, sel, &obj) + case *B: + return ec._B(ctx, sel, obj) + default: + panic(fmt.Errorf("unexpected type %T", obj)) + } +} + // endregion ************************** interface.gotpl *************************** // region **************************** object.gotpl **************************** +var aImplementors = []string{"A", "TestUnion"} + +func (ec *executionContext) _A(ctx context.Context, sel ast.SelectionSet, obj *A) graphql.Marshaler { + fields := graphql.CollectFields(ctx, sel, aImplementors) + + out := graphql.NewFieldSet(fields) + invalid := false + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("A") + case "id": + out.Values[i] = ec._A_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalid = true + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalid { + return graphql.Null + } + return out +} + var aItImplementors = []string{"AIt"} func (ec *executionContext) _AIt(ctx context.Context, sel ast.SelectionSet, obj *AIt) graphql.Marshaler { @@ -5504,6 +5671,33 @@ func (ec *executionContext) _Autobind(ctx context.Context, sel ast.SelectionSet, return out } +var bImplementors = []string{"B", "TestUnion"} + +func (ec *executionContext) _B(ctx context.Context, sel ast.SelectionSet, obj *B) graphql.Marshaler { + fields := graphql.CollectFields(ctx, sel, bImplementors) + + out := graphql.NewFieldSet(fields) + invalid := false + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("B") + case "id": + out.Values[i] = ec._B_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalid = true + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalid { + return graphql.Null + } + return out +} + var circleImplementors = []string{"Circle", "Shape", "ShapeUnion"} func (ec *executionContext) _Circle(ctx context.Context, sel ast.SelectionSet, obj *Circle) graphql.Marshaler { @@ -6280,6 +6474,17 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr res = ec._Query_slices(ctx, field) return res }) + case "optionalUnion": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_optionalUnion(ctx, field) + return res + }) case "validType": field := field out.Concurrently(i, func() (res graphql.Marshaler) { @@ -7843,6 +8048,10 @@ func (ec *executionContext) marshalOString2ᚖstring(ctx context.Context, sel as return ec.marshalOString2string(ctx, sel, *v) } +func (ec *executionContext) marshalOTestUnion2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐTestUnion(ctx context.Context, sel ast.SelectionSet, v TestUnion) graphql.Marshaler { + return ec._TestUnion(ctx, sel, &v) +} + func (ec *executionContext) unmarshalOThirdParty2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐThirdParty(ctx context.Context, v interface{}) (ThirdParty, error) { return UnmarshalThirdParty(v) } diff --git a/codegen/testserver/models-gen.go b/codegen/testserver/models-gen.go index b09dcfdf99..25f8522713 100644 --- a/codegen/testserver/models-gen.go +++ b/codegen/testserver/models-gen.go @@ -9,6 +9,16 @@ import ( "time" ) +type TestUnion interface { + IsTestUnion() +} + +type A struct { + ID string `json:"id"` +} + +func (A) IsTestUnion() {} + type AIt struct { ID string `json:"id"` } @@ -17,6 +27,12 @@ type AbIt struct { ID string `json:"id"` } +type B struct { + ID string `json:"id"` +} + +func (B) IsTestUnion() {} + type EmbeddedDefaultScalar struct { Value *string `json:"value"` } diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index f42787e6ef..e0fba141e8 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -140,6 +140,9 @@ func (r *queryResolver) DefaultScalar(ctx context.Context, arg string) (string, func (r *queryResolver) Slices(ctx context.Context) (*Slices, error) { panic("not implemented") } +func (r *queryResolver) OptionalUnion(ctx context.Context) (TestUnion, error) { + panic("not implemented") +} func (r *queryResolver) ValidType(ctx context.Context) (*ValidType, error) { panic("not implemented") } diff --git a/codegen/testserver/stub.go b/codegen/testserver/stub.go index 3965ce4c20..6706ead178 100644 --- a/codegen/testserver/stub.go +++ b/codegen/testserver/stub.go @@ -50,6 +50,7 @@ type Stub struct { Panics func(ctx context.Context) (*Panics, error) DefaultScalar func(ctx context.Context, arg string) (string, error) Slices func(ctx context.Context) (*Slices, error) + OptionalUnion func(ctx context.Context) (TestUnion, error) ValidType func(ctx context.Context) (*ValidType, error) } SubscriptionResolver struct { @@ -190,6 +191,9 @@ func (r *stubQuery) DefaultScalar(ctx context.Context, arg string) (string, erro func (r *stubQuery) Slices(ctx context.Context) (*Slices, error) { return r.QueryResolver.Slices(ctx) } +func (r *stubQuery) OptionalUnion(ctx context.Context) (TestUnion, error) { + return r.QueryResolver.OptionalUnion(ctx) +} func (r *stubQuery) ValidType(ctx context.Context) (*ValidType, error) { return r.QueryResolver.ValidType(ctx) } diff --git a/codegen/testserver/useptr.graphql b/codegen/testserver/useptr.graphql new file mode 100644 index 0000000000..23c1af0b42 --- /dev/null +++ b/codegen/testserver/useptr.graphql @@ -0,0 +1,13 @@ +type A { + id: ID! +} + +type B { + id: ID! +} + +union TestUnion = A | B + +extend type Query { + optionalUnion: TestUnion +} diff --git a/codegen/testserver/useptr_test.go b/codegen/testserver/useptr_test.go new file mode 100644 index 0000000000..ba088f49dc --- /dev/null +++ b/codegen/testserver/useptr_test.go @@ -0,0 +1,14 @@ +package testserver + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUserPtr(t *testing.T) { + s := &Stub{} + r := reflect.TypeOf(s.QueryResolver.OptionalUnion) + require.True(t, r.Out(0).Kind() == reflect.Interface) +}