diff --git a/graphql_test.go b/graphql_test.go index 8c1e142283..b52609682f 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -1688,3 +1688,59 @@ func TestInput(t *testing.T) { }, }) } + +func TestComposedFragments(t *testing.T) { + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: starwarsSchema, + Query: ` + { + composed: hero(episode: EMPIRE) { + name + ...friendsNames + ...friendsIds + } + } + + fragment friendsNames on Character { + name + friends { + name + } + } + + fragment friendsIds on Character { + name + friends { + id + } + } + `, + ExpectedResult: ` + { + "composed": { + "name": "Luke Skywalker", + "friends": [ + { + "id": "1002", + "name": "Han Solo" + }, + { + "id": "1003", + "name": "Leia Organa" + }, + { + "id": "2000", + "name": "C-3PO" + }, + { + "id": "2001", + "name": "R2-D2" + } + ] + } + } + `, + }, + }) +} diff --git a/internal/exec/exec.go b/internal/exec/exec.go index b14c197ea6..2f8a52323d 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -55,26 +55,28 @@ func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *query.O return out.Bytes(), r.Errs } -type fieldWithResolver struct { +type fieldToExec struct { field *selected.SchemaField + sels []selected.Selection resolver reflect.Value - out bytes.Buffer + out *bytes.Buffer } func (r *Request) execSelections(ctx context.Context, sels []selected.Selection, resolver reflect.Value, out *bytes.Buffer, serially bool) { async := !serially && selected.HasAsyncSel(sels) - var fields []*fieldWithResolver - collectFieldsToResolve(sels, resolver, &fields) + var fields []*fieldToExec + collectFieldsToResolve(sels, resolver, &fields, make(map[string]*fieldToExec)) if async { var wg sync.WaitGroup wg.Add(len(fields)) for _, f := range fields { - go func(f *fieldWithResolver) { + go func(f *fieldToExec) { defer wg.Done() defer r.handlePanic(ctx) - r.execFieldSelection(ctx, f.field, f.resolver, &f.out, false) + f.out = new(bytes.Buffer) + execFieldSelection(ctx, r, f, false) }(f) } wg.Wait() @@ -93,16 +95,23 @@ func (r *Request) execSelections(ctx context.Context, sels []selected.Selection, out.Write(f.out.Bytes()) continue } - r.execFieldSelection(ctx, f.field, f.resolver, out, false) + f.out = out + execFieldSelection(ctx, r, f, false) } out.WriteByte('}') } -func collectFieldsToResolve(sels []selected.Selection, resolver reflect.Value, fields *[]*fieldWithResolver) { +func collectFieldsToResolve(sels []selected.Selection, resolver reflect.Value, fields *[]*fieldToExec, fieldByAlias map[string]*fieldToExec) { for _, sel := range sels { switch sel := sel.(type) { case *selected.SchemaField: - *fields = append(*fields, &fieldWithResolver{field: sel, resolver: resolver}) + field, ok := fieldByAlias[sel.Alias] + if !ok { // validation already checked for conflict (TODO) + field = &fieldToExec{field: sel, resolver: resolver} + fieldByAlias[sel.Alias] = field + *fields = append(*fields, field) + } + field.sels = append(field.sels, sel.Sels...) case *selected.TypenameField: sf := &selected.SchemaField{ @@ -110,14 +119,14 @@ func collectFieldsToResolve(sels []selected.Selection, resolver reflect.Value, f Alias: sel.Alias, FixedResult: reflect.ValueOf(typeOf(sel, resolver)), } - *fields = append(*fields, &fieldWithResolver{field: sf, resolver: resolver}) + *fields = append(*fields, &fieldToExec{field: sf, resolver: resolver}) case *selected.TypeAssertion: out := resolver.Method(sel.MethodIndex).Call(nil) if !out[1].Bool() { continue } - collectFieldsToResolve(sel.Sels, out[0], fields) + collectFieldsToResolve(sel.Sels, out[0], fields, fieldByAlias) default: panic("unreachable") @@ -138,7 +147,7 @@ func typeOf(tf *selected.TypenameField, resolver reflect.Value) string { return "" } -func (r *Request) execFieldSelection(ctx context.Context, field *selected.SchemaField, resolver reflect.Value, out *bytes.Buffer, applyLimiter bool) { +func execFieldSelection(ctx context.Context, r *Request, f *fieldToExec, applyLimiter bool) { if applyLimiter { r.Limiter <- struct{}{} } @@ -146,7 +155,7 @@ func (r *Request) execFieldSelection(ctx context.Context, field *selected.Schema var result reflect.Value var err *errors.QueryError - traceCtx, finish := r.Tracer.TraceField(ctx, field.TraceLabel, field.TypeName, field.Name, !field.Async, field.Args) + traceCtx, finish := r.Tracer.TraceField(ctx, f.field.TraceLabel, f.field.TypeName, f.field.Name, !f.field.Async, f.field.Args) defer func() { finish(err) }() @@ -159,8 +168,8 @@ func (r *Request) execFieldSelection(ctx context.Context, field *selected.Schema } }() - if field.FixedResult.IsValid() { - result = field.FixedResult + if f.field.FixedResult.IsValid() { + result = f.field.FixedResult return nil } @@ -169,15 +178,15 @@ func (r *Request) execFieldSelection(ctx context.Context, field *selected.Schema } var in []reflect.Value - if field.HasContext { + if f.field.HasContext { in = append(in, reflect.ValueOf(traceCtx)) } - if field.ArgsPacker != nil { - in = append(in, field.PackedArgs) + if f.field.ArgsPacker != nil { + in = append(in, f.field.PackedArgs) } - callOut := resolver.Method(field.MethodIndex).Call(in) + callOut := f.resolver.Method(f.field.MethodIndex).Call(in) result = callOut[0] - if field.HasError && !callOut[1].IsNil() { + if f.field.HasError && !callOut[1].IsNil() { resolverErr := callOut[1].Interface().(error) err := errors.Errorf("%s", resolverErr) err.ResolverError = resolverErr @@ -192,11 +201,11 @@ func (r *Request) execFieldSelection(ctx context.Context, field *selected.Schema if err != nil { r.AddError(err) - out.WriteString("null") // TODO handle non-nil + f.out.WriteString("null") // TODO handle non-nil return } - r.execSelectionSet(traceCtx, field.Sels, field.Type, result, out) + r.execSelectionSet(traceCtx, f.sels, f.field.Type, result, f.out) } func (r *Request) execSelectionSet(ctx context.Context, sels []selected.Selection, typ common.Type, resolver reflect.Value, out *bytes.Buffer) {