From 7aad6ba78fc19e4b89b0eae12263e79f56eeb97b Mon Sep 17 00:00:00 2001 From: Richard Musiol Date: Wed, 31 May 2017 16:25:13 +0200 Subject: [PATCH] refactor: use schema.NamedType --- internal/validation/validation.go | 33 +++++++++++++++++++------------ 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/internal/validation/validation.go b/internal/validation/validation.go index ac3cec7161..28c40c6a51 100644 --- a/internal/validation/validation.go +++ b/internal/validation/validation.go @@ -87,7 +87,7 @@ func Validate(s *schema.Schema, doc *query.Document) []*errors.QueryError { } } - var entryPoint common.Type + var entryPoint schema.NamedType switch op.Type { case query.Query: entryPoint = s.EntryPoints["query"] @@ -116,7 +116,7 @@ func Validate(s *schema.Schema, doc *query.Document) []*errors.QueryError { validateName(c, fragNames, frag.Name, "UniqueFragmentNames", "fragment") validateDirectives(opc, "FRAGMENT_DEFINITION", frag.Directives) - t := resolveType(c, &frag.On) + t := unwrapType(resolveType(c, &frag.On)) // continue even if t is nil if t != nil && !canBeFragment(t) { c.addErr(frag.On.Loc, "FragmentsOnCompositeTypes", "Fragment %q cannot condition on non composite type %q.", frag.Name.Name, t) @@ -154,13 +154,13 @@ func Validate(s *schema.Schema, doc *query.Document) []*errors.QueryError { return c.errs } -func validateSelectionSet(c *opContext, selSet *query.SelectionSet, t common.Type) { +func validateSelectionSet(c *opContext, selSet *query.SelectionSet, t schema.NamedType) { for _, sel := range selSet.Selections { validateSelection(c, sel, t) } } -func validateSelection(c *opContext, sel query.Selection, t common.Type) { +func validateSelection(c *opContext, sel query.Selection, t schema.NamedType) { switch sel := sel.(type) { case *query.Field: validateDirectives(c, "FIELD", sel.Directives) @@ -223,7 +223,7 @@ func validateSelection(c *opContext, sel query.Selection, t common.Type) { case *query.InlineFragment: validateDirectives(c, "INLINE_FRAGMENT", sel.Directives) if sel.On.Name != "" { - fragTyp := resolveType(c.context, &sel.On) + fragTyp := unwrapType(resolveType(c.context, &sel.On)) if fragTyp != nil && !compatible(t, fragTyp) { c.addErr(sel.Loc, "PossibleFragmentSpreads", "Fragment cannot be spread here as objects of type %q can never be of type %q.", t, fragTyp) } @@ -373,14 +373,21 @@ func fields(t common.Type) schema.FieldList { } } -func unwrapType(t common.Type) common.Type { - switch t := t.(type) { - case *common.List: - return unwrapType(t.OfType) - case *common.NonNull: - return unwrapType(t.OfType) - default: - return t +func unwrapType(t common.Type) schema.NamedType { + if t == nil { + return nil + } + for { + switch t2 := t.(type) { + case schema.NamedType: + return t2 + case *common.List: + t = t2.OfType + case *common.NonNull: + t = t2.OfType + default: + panic("unreachable") + } } }