diff --git a/api/generate.go b/api/generate.go index 44f32c4eaf..b64626401b 100644 --- a/api/generate.go +++ b/api/generate.go @@ -9,9 +9,7 @@ import ( "github.com/99designs/gqlgen/plugin/federation" "github.com/99designs/gqlgen/plugin/modelgen" "github.com/99designs/gqlgen/plugin/resolvergen" - "github.com/99designs/gqlgen/plugin/schemaconfig" "github.com/pkg/errors" - "golang.org/x/tools/go/packages" ) func Generate(cfg *config.Config, option ...Option) error { @@ -19,11 +17,8 @@ func Generate(cfg *config.Config, option ...Option) error { if cfg.Model.IsDefined() { _ = syscall.Unlink(cfg.Model.Filename) } - if err := cfg.Check(); err != nil { - return errors.Wrap(err, "generating core failed") - } - plugins := []plugin.Plugin{schemaconfig.New()} + plugins := []plugin.Plugin{} if cfg.Model.IsDefined() { plugins = append(plugins, modelgen.New()) } @@ -36,16 +31,30 @@ func Generate(cfg *config.Config, option ...Option) error { o(cfg, &plugins) } - schemaMutators := []codegen.SchemaMutator{} for _, p := range plugins { if inj, ok := p.(plugin.SourcesInjector); ok { inj.InjectSources(cfg) } - if mut, ok := p.(codegen.SchemaMutator); ok { - schemaMutators = append(schemaMutators, mut) + } + + err := cfg.LoadSchema() + if err != nil { + return errors.Wrap(err, "failed to load schema") + } + + for _, p := range plugins { + if mut, ok := p.(plugin.SchemaMutator); ok { + err := mut.MutateSchema(cfg.Schema) + if err != nil { + return errors.Wrap(err, p.Name()) + } } } + if err := cfg.Init(); err != nil { + return errors.Wrap(err, "generating core failed") + } + for _, p := range plugins { if mut, ok := p.(plugin.ConfigMutator); ok { err := mut.MutateConfig(cfg) @@ -55,7 +64,7 @@ func Generate(cfg *config.Config, option ...Option) error { } } // Merge again now that the generated models have been injected into the typemap - data, err := codegen.BuildData(cfg, schemaMutators) + data, err := codegen.BuildData(cfg) if err != nil { return errors.Wrap(err, "merging type systems failed") } @@ -95,17 +104,11 @@ func validate(cfg *config.Config) error { if cfg.Resolver.IsDefined() { roots = append(roots, cfg.Resolver.ImportPath()) } - _, err := packages.Load(&packages.Config{ - Mode: packages.NeedName | - packages.NeedFiles | - packages.NeedCompiledGoFiles | - packages.NeedImports | - packages.NeedTypes | - packages.NeedTypesSizes | - packages.NeedSyntax | - packages.NeedTypesInfo}, roots...) - if err != nil { - return errors.Wrap(err, "validation failed") + + cfg.Packages.LoadAll(roots...) + errs := cfg.Packages.Errors() + if len(errs) > 0 { + return errs } return nil } diff --git a/codegen/config/binder.go b/codegen/config/binder.go index d42b9ec523..26c9cdb731 100644 --- a/codegen/config/binder.go +++ b/codegen/config/binder.go @@ -1,7 +1,6 @@ package config import ( - "bytes" "fmt" "go/token" "go/types" @@ -10,73 +9,22 @@ import ( "github.com/99designs/gqlgen/internal/code" "github.com/pkg/errors" "github.com/vektah/gqlparser/ast" - "golang.org/x/tools/go/packages" ) // Binder connects graphql types to golang types using static analysis type Binder struct { - pkgs map[string]*packages.Package + pkgs *code.Packages schema *ast.Schema cfg *Config References []*TypeReference - PkgErrors PkgErrors SawInvalid bool } -func (c *Config) NewBinder(s *ast.Schema) (*Binder, error) { - pkgs, err := packages.Load(&packages.Config{ - Mode: packages.NeedName | - packages.NeedFiles | - packages.NeedCompiledGoFiles | - packages.NeedImports | - packages.NeedTypes | - packages.NeedTypesSizes | - packages.NeedSyntax | - packages.NeedTypesInfo, - }, c.Models.ReferencedPackages()...) - if err != nil { - return nil, err - } - - mp := map[string]*packages.Package{} - var pkgErrs PkgErrors - for _, p := range pkgs { - populatePkg(mp, p) - for _, e := range p.Errors { - if e.Kind == packages.ListError { - return nil, e - } - } - pkgErrs = append(pkgErrs, p.Errors...) - } - +func (c *Config) NewBinder() *Binder { return &Binder{ - pkgs: mp, - schema: s, - cfg: c, - PkgErrors: pkgErrs, - }, nil -} - -type PkgErrors []packages.Error - -func (p PkgErrors) Error() string { - var b bytes.Buffer - b.WriteString("packages.Load: ") - for _, e := range p { - b.WriteString(e.Error() + "\n") - } - return b.String() -} - -func populatePkg(mp map[string]*packages.Package, p *packages.Package) { - imp := code.NormalizeVendor(p.PkgPath) - if _, ok := mp[imp]; ok { - return - } - mp[imp] = p - for _, p := range p.Imports { - populatePkg(mp, p) + pkgs: c.Packages, + schema: c.Schema, + cfg: c, } } @@ -97,7 +45,7 @@ func (b *Binder) ObjectPosition(typ types.Object) token.Position { Filename: "unknown", } } - pkg := b.getPkg(typ.Pkg().Path()) + pkg := b.pkgs.Load(typ.Pkg().Path()) return pkg.Fset.Position(typ.Pos()) } @@ -128,14 +76,6 @@ func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) { return obj.Type(), nil } -func (b *Binder) getPkg(find string) *packages.Package { - imp := code.NormalizeVendor(find) - if p, ok := b.pkgs[imp]; ok { - return p - } - return nil -} - var MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete()) var InterfaceType = types.NewInterfaceType(nil, nil) @@ -175,7 +115,7 @@ func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, erro fullName = pkgName + "." + typeName } - pkg := b.getPkg(pkgName) + pkg := b.pkgs.LoadWithTypes(pkgName) if pkg == nil { return nil, errors.Errorf("required package was not loaded: %s", fullName) } diff --git a/codegen/config/binder_test.go b/codegen/config/binder_test.go index 6a842a12da..a5c35bd185 100644 --- a/codegen/config/binder_test.go +++ b/codegen/config/binder_test.go @@ -4,6 +4,8 @@ import ( "go/types" "testing" + "github.com/99designs/gqlgen/internal/code" + "github.com/stretchr/testify/require" "github.com/vektah/gqlparser" "github.com/vektah/gqlparser/ast" @@ -49,8 +51,9 @@ func createBinder(cfg Config) (*Binder, *ast.Schema) { Model: []string{"github.com/99designs/gqlgen/example/chat.Message"}, }, } + cfg.Packages = &code.Packages{} - s := gqlparser.MustLoadSchema(&ast.Source{Name: "TestAutobinding.schema", Input: ` + cfg.Schema = gqlparser.MustLoadSchema(&ast.Source{Name: "TestAutobinding.schema", Input: ` type Message { id: ID } type Query { @@ -58,10 +61,7 @@ func createBinder(cfg Config) (*Binder, *ast.Schema) { } `}) - b, err := cfg.NewBinder(s) - if err != nil { - panic(err) - } + b := cfg.NewBinder() - return b, s + return b, cfg.Schema } diff --git a/codegen/config/config.go b/codegen/config/config.go index dcd4d3478c..10d95aac9f 100644 --- a/codegen/config/config.go +++ b/codegen/config/config.go @@ -9,8 +9,6 @@ import ( "sort" "strings" - "golang.org/x/tools/go/packages" - "github.com/99designs/gqlgen/internal/code" "github.com/pkg/errors" "github.com/vektah/gqlparser" @@ -31,6 +29,8 @@ type Config struct { SkipValidation bool `yaml:"skip_validation,omitempty"` Federated bool `yaml:"federated,omitempty"` AdditionalSources []*ast.Source `yaml:"-"` + Packages *code.Packages `yaml:"-"` + Schema *ast.Schema `yaml:"-"` } var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"} @@ -42,6 +42,7 @@ func DefaultConfig() *Config { Model: PackageConfig{Filename: "models_gen.go"}, Exec: PackageConfig{Filename: "generated.go"}, Directives: map[string]DirectiveConfig{}, + Models: TypeMap{}, } } @@ -137,6 +138,113 @@ func LoadConfig(filename string) (*Config, error) { return config, nil } +func (c *Config) Init() error { + if c.Packages == nil { + c.Packages = &code.Packages{} + } + + if c.Schema == nil { + if err := c.LoadSchema(); err != nil { + return err + } + } + + err := c.injectTypesFromSchema() + if err != nil { + return err + } + + err = c.autobind() + if err != nil { + return err + } + + c.injectBuiltins() + + // prefetch all packages in one big packages.Load call + pkgs := []string{ + "github.com/99designs/gqlgen/graphql", + "github.com/99designs/gqlgen/graphql/introspection", + } + pkgs = append(pkgs, c.Models.ReferencedPackages()...) + pkgs = append(pkgs, c.AutoBind...) + c.Packages.LoadAll(pkgs...) + + // check everything is valid on the way out + err = c.check() + if err != nil { + return err + } + + return nil +} + +func (c *Config) injectTypesFromSchema() error { + c.Directives["goModel"] = DirectiveConfig{ + SkipRuntime: true, + } + + c.Directives["goField"] = DirectiveConfig{ + SkipRuntime: true, + } + + for _, schemaType := range c.Schema.Types { + if schemaType == c.Schema.Query || schemaType == c.Schema.Mutation || schemaType == c.Schema.Subscription { + continue + } + + if bd := schemaType.Directives.ForName("goModel"); bd != nil { + if ma := bd.Arguments.ForName("model"); ma != nil { + if mv, err := ma.Value.Value(nil); err == nil { + c.Models.Add(schemaType.Name, mv.(string)) + } + } + if ma := bd.Arguments.ForName("models"); ma != nil { + if mvs, err := ma.Value.Value(nil); err == nil { + for _, mv := range mvs.([]interface{}) { + c.Models.Add(schemaType.Name, mv.(string)) + } + } + } + } + + if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject { + for _, field := range schemaType.Fields { + if fd := field.Directives.ForName("goField"); fd != nil { + forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver + fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName + + if ra := fd.Arguments.ForName("forceResolver"); ra != nil { + if fr, err := ra.Value.Value(nil); err == nil { + forceResolver = fr.(bool) + } + } + + if na := fd.Arguments.ForName("name"); na != nil { + if fr, err := na.Value.Value(nil); err == nil { + fieldName = fr.(string) + } + } + + if c.Models[schemaType.Name].Fields == nil { + c.Models[schemaType.Name] = TypeMapEntry{ + Model: c.Models[schemaType.Name].Model, + Fields: map[string]TypeMapField{}, + } + } + + c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{ + FieldName: fieldName, + Resolver: forceResolver, + } + } + } + } + } + + return nil +} + type TypeMapEntry struct { Model StringList `yaml:"model"` Fields map[string]TypeMapField `yaml:"fields,omitempty"` @@ -177,7 +285,7 @@ func (a StringList) Has(file string) bool { return false } -func (c *Config) Check() error { +func (c *Config) check() error { if c.Models == nil { c.Models = TypeMap{} } @@ -339,24 +447,14 @@ func findCfgInDir(dir string) string { return "" } -func (c *Config) Autobind(s *ast.Schema) error { +func (c *Config) autobind() error { if len(c.AutoBind) == 0 { return nil } - ps, err := packages.Load(&packages.Config{ - Mode: packages.NeedName | - packages.NeedFiles | - packages.NeedCompiledGoFiles | - packages.NeedImports | - packages.NeedTypes | - packages.NeedTypesSizes, - }, c.AutoBind...) - if err != nil { - return err - } + ps := c.Packages.LoadAll(c.AutoBind...) - for _, t := range s.Types { + for _, t := range c.Schema.Types { if c.Models.UserDefined(t.Name) { continue } @@ -393,7 +491,7 @@ func (c *Config) Autobind(s *ast.Schema) error { return nil } -func (c *Config) InjectBuiltins(s *ast.Schema) { +func (c *Config) injectBuiltins() { builtins := TypeMap{ "__Directive": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}}, "__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}}, @@ -434,13 +532,21 @@ func (c *Config) InjectBuiltins(s *ast.Schema) { } for typeName, entry := range extraBuiltins { - if t, ok := s.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar { + if t, ok := c.Schema.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar { c.Models[typeName] = entry } } } -func (c *Config) LoadSchema() (*ast.Schema, error) { +func (c *Config) LoadSchema() error { + if c.Packages != nil { + c.Packages = &code.Packages{} + } + + if err := c.check(); err != nil { + return err + } + sources := append([]*ast.Source{}, c.AdditionalSources...) for _, filename := range c.SchemaFilename { filename = filepath.ToSlash(filename) @@ -456,9 +562,10 @@ func (c *Config) LoadSchema() (*ast.Schema, error) { schema, err := gqlparser.LoadSchema(sources...) if err != nil { - return nil, err + return err } - return schema, nil + c.Schema = schema + return nil } func abs(path string) string { diff --git a/codegen/config/config_test.go b/codegen/config/config_test.go index 9ff43f2bc2..292090294c 100644 --- a/codegen/config/config_test.go +++ b/codegen/config/config_test.go @@ -11,6 +11,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/99designs/gqlgen/internal/code" ) func TestLoadConfig(t *testing.T) { @@ -110,7 +112,7 @@ func TestConfigCheck(t *testing.T) { config, err := LoadConfig("testdata/cfg/conflictedPackages.yml") require.NoError(t, err) - err = config.Check() + err = config.check() require.EqualError(t, err, "exec and model define the same import path (github.com/99designs/gqlgen/codegen/config/generated) with different package names (graphql vs generated)") }) } @@ -122,14 +124,15 @@ func TestAutobinding(t *testing.T) { "github.com/99designs/gqlgen/example/chat", "github.com/99designs/gqlgen/example/scalars/model", }, + Packages: &code.Packages{}, } - s := gqlparser.MustLoadSchema(&ast.Source{Name: "TestAutobinding.schema", Input: ` + cfg.Schema = gqlparser.MustLoadSchema(&ast.Source{Name: "TestAutobinding.schema", Input: ` scalar Banned type Message { id: ID } `}) - require.NoError(t, cfg.Autobind(s)) + require.NoError(t, cfg.autobind()) require.Equal(t, "github.com/99designs/gqlgen/example/scalars/model.Banned", cfg.Models["Banned"].Model[0]) require.Equal(t, "github.com/99designs/gqlgen/example/chat.Message", cfg.Models["Message"].Model[0]) diff --git a/codegen/data.go b/codegen/data.go index d3b191f8f9..e30b33c83c 100644 --- a/codegen/data.go +++ b/codegen/data.go @@ -6,11 +6,9 @@ import ( "sort" "github.com/99designs/gqlgen/codegen/config" - "github.com/99designs/gqlgen/internal/code" "github.com/pkg/errors" "github.com/vektah/gqlparser/ast" "github.com/vektah/gqlparser/formatter" - "golang.org/x/tools/go/packages" ) // Data is a unified model of the code to be generated. Plugins may modify this structure to do things like implement @@ -39,45 +37,15 @@ type builder struct { Directives map[string]*Directive } -type SchemaMutator interface { - MutateSchema(s *ast.Schema) error -} - -func BuildData(cfg *config.Config, plugins []SchemaMutator) (*Data, error) { +func BuildData(cfg *config.Config) (*Data, error) { b := builder{ Config: cfg, + Schema: cfg.Schema, } - var err error - b.Schema, err = cfg.LoadSchema() - if err != nil { - return nil, err - } - - err = cfg.Check() - if err != nil { - return nil, err - } - - err = cfg.Autobind(b.Schema) - if err != nil { - return nil, err - } - - cfg.InjectBuiltins(b.Schema) - - for _, p := range plugins { - err = p.MutateSchema(b.Schema) - if err != nil { - return nil, fmt.Errorf("error running MutateSchema: %v", err) - } - } - - b.Binder, err = b.Config.NewBinder(b.Schema) - if err != nil { - return nil, err - } + b.Binder = b.Config.NewBinder() + var err error b.Directives, err = b.buildDirectives() if err != nil { return nil, err @@ -90,12 +58,6 @@ func BuildData(cfg *config.Config, plugins []SchemaMutator) (*Data, error) { } } - pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, cfg.Models.ReferencedPackages()...) - if err != nil { - return nil, errors.Wrap(err, "loading failed") - } - code.RecordPackagesList(pkgs) - s := Data{ Config: cfg, Directives: dataDirectives, @@ -156,8 +118,9 @@ func BuildData(cfg *config.Config, plugins []SchemaMutator) (*Data, error) { if b.Binder.SawInvalid { // if we have a syntax error, show it - if len(b.Binder.PkgErrors) > 0 { - return nil, b.Binder.PkgErrors + err := cfg.Packages.Errors() + if len(err) > 0 { + return nil, err } // otherwise show a generic error message diff --git a/codegen/generate.go b/codegen/generate.go index eafa3f8743..f1ed2ca27b 100644 --- a/codegen/generate.go +++ b/codegen/generate.go @@ -11,5 +11,6 @@ func GenerateCode(data *Data) error { Data: data, RegionTags: true, GeneratedHeader: true, + Packages: data.Config.Packages, }) } diff --git a/codegen/templates/import.go b/codegen/templates/import.go index d5bd16a6a1..17bd96ab2e 100644 --- a/codegen/templates/import.go +++ b/codegen/templates/import.go @@ -16,8 +16,9 @@ type Import struct { } type Imports struct { - imports []*Import - destDir string + imports []*Import + destDir string + packages *code.Packages } func (i *Import) String() string { @@ -49,7 +50,7 @@ func (s *Imports) Reserve(path string, aliases ...string) (string, error) { return "", nil } - name := code.NameForPackage(path) + name := s.packages.NameForPackage(path) var alias string if len(aliases) != 1 { alias = name @@ -94,7 +95,7 @@ func (s *Imports) Lookup(path string) string { } imp := &Import{ - Name: code.NameForPackage(path), + Name: s.packages.NameForPackage(path), Path: path, } s.imports = append(s.imports, imp) diff --git a/codegen/templates/import_test.go b/codegen/templates/import_test.go index 440b59147c..2e8dd5a89a 100644 --- a/codegen/templates/import_test.go +++ b/codegen/templates/import_test.go @@ -6,8 +6,8 @@ import ( "testing" "github.com/99designs/gqlgen/internal/code" + "github.com/stretchr/testify/require" - "golang.org/x/tools/go/packages" ) func TestImports(t *testing.T) { @@ -18,20 +18,15 @@ func TestImports(t *testing.T) { bBar := "github.com/99designs/gqlgen/codegen/templates/testdata/b/bar" mismatch := "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch" - ps, err := packages.Load(nil, aBar, bBar, mismatch) - require.NoError(t, err) - - code.RecordPackagesList(ps) - t.Run("multiple lookups is ok", func(t *testing.T) { - a := Imports{destDir: wd} + a := Imports{destDir: wd, packages: &code.Packages{}} require.Equal(t, "bar", a.Lookup(aBar)) require.Equal(t, "bar", a.Lookup(aBar)) }) t.Run("lookup by type", func(t *testing.T) { - a := Imports{destDir: wd} + a := Imports{destDir: wd, packages: &code.Packages{}} pkg := types.NewPackage("github.com/99designs/gqlgen/codegen/templates/testdata/b/bar", "bar") typ := types.NewNamed(types.NewTypeName(0, pkg, "Boolean", types.Typ[types.Bool]), types.Typ[types.Bool], nil) @@ -40,7 +35,7 @@ func TestImports(t *testing.T) { }) t.Run("duplicates are decollisioned", func(t *testing.T) { - a := Imports{destDir: wd} + a := Imports{destDir: wd, packages: &code.Packages{}} require.Equal(t, "bar", a.Lookup(aBar)) require.Equal(t, "bar1", a.Lookup(bBar)) @@ -51,13 +46,13 @@ func TestImports(t *testing.T) { }) t.Run("package name defined in code will be used", func(t *testing.T) { - a := Imports{destDir: wd} + a := Imports{destDir: wd, packages: &code.Packages{}} require.Equal(t, "turtles", a.Lookup(mismatch)) }) t.Run("string printing for import block", func(t *testing.T) { - a := Imports{destDir: wd} + a := Imports{destDir: wd, packages: &code.Packages{}} a.Lookup(aBar) a.Lookup(bBar) a.Lookup(mismatch) @@ -72,7 +67,7 @@ turtles "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch"`, }) t.Run("aliased imports will not collide", func(t *testing.T) { - a := Imports{destDir: wd} + a := Imports{destDir: wd, packages: &code.Packages{}} _, _ = a.Reserve(aBar, "abar") _, _ = a.Reserve(bBar, "bbar") diff --git a/codegen/templates/templates.go b/codegen/templates/templates.go index e872244253..b10dd68b4d 100644 --- a/codegen/templates/templates.go +++ b/codegen/templates/templates.go @@ -15,6 +15,8 @@ import ( "text/template" "unicode" + "github.com/99designs/gqlgen/internal/code" + "github.com/99designs/gqlgen/internal/imports" "github.com/pkg/errors" ) @@ -45,6 +47,9 @@ type Options struct { // Data will be passed to the template execution. Data interface{} Funcs template.FuncMap + + // Packages cache, you can find me on config.Config + Packages *code.Packages } // Render renders a gql plugin template from the given Options. Render is an @@ -55,7 +60,7 @@ func Render(cfg Options) error { if CurrentImports != nil { panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected")) } - CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)} + CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)} // load path relative to calling source file _, callerFile, _, _ := runtime.Caller(1) @@ -148,7 +153,13 @@ func Render(cfg Options) error { } CurrentImports = nil - return write(cfg.Filename, result.Bytes()) + err = write(cfg.Filename, result.Bytes(), cfg.Packages) + if err != nil { + return err + } + + cfg.Packages.Evict(code.ImportPathForDir(filepath.Dir(cfg.Filename))) + return nil } func center(width int, pad string, s string) string { @@ -556,13 +567,13 @@ func render(filename string, tpldata interface{}) (*bytes.Buffer, error) { return buf, t.Execute(buf, tpldata) } -func write(filename string, b []byte) error { +func write(filename string, b []byte, packages *code.Packages) error { err := os.MkdirAll(filepath.Dir(filename), 0755) if err != nil { return errors.Wrap(err, "failed to create directory") } - formatted, err := imports.Prune(filename, b) + formatted, err := imports.Prune(filename, b, packages) if err != nil { fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error()) formatted = b diff --git a/codegen/templates/templates_test.go b/codegen/templates/templates_test.go index 98df68360e..b02698b0b6 100644 --- a/codegen/templates/templates_test.go +++ b/codegen/templates/templates_test.go @@ -5,6 +5,8 @@ import ( "os" "testing" + "github.com/99designs/gqlgen/internal/code" + "github.com/stretchr/testify/require" ) @@ -110,7 +112,7 @@ func TestTemplateOverride(t *testing.T) { } defer f.Close() defer os.RemoveAll(f.Name()) - err = Render(Options{Template: "hello", Filename: f.Name()}) + err = Render(Options{Template: "hello", Filename: f.Name(), Packages: &code.Packages{}}) if err != nil { t.Fatal(err) } diff --git a/codegen/testserver/gqlgen.yml b/codegen/testserver/gqlgen.yml index 83b5777055..ab0d95c494 100644 --- a/codegen/testserver/gqlgen.yml +++ b/codegen/testserver/gqlgen.yml @@ -1,6 +1,6 @@ schema: - "*.graphql" - +skip_validation: true exec: filename: generated.go model: diff --git a/internal/code/imports.go b/internal/code/imports.go index 10b325ba6f..e861a6eb5d 100644 --- a/internal/code/imports.go +++ b/internal/code/imports.go @@ -1,8 +1,6 @@ package code import ( - "errors" - "fmt" "go/build" "go/parser" "go/token" @@ -10,14 +8,8 @@ import ( "path/filepath" "regexp" "strings" - "sync" - - "golang.org/x/tools/go/packages" ) -var nameForPackageCacheLock sync.Mutex -var nameForPackageCache []*packages.Package - var gopaths []string func init() { @@ -108,33 +100,3 @@ func ImportPathForDir(dir string) (res string) { } var modregex = regexp.MustCompile("module (.*)\n") - -// RecordPackagesList records the list of packages to be used later by NameForPackage. -// It must be called exactly once during initialization, before NameForPackage is called. -func RecordPackagesList(newNameForPackageCache []*packages.Package) { - nameForPackageCache = newNameForPackageCache -} - -// NameForPackage returns the package name for a given import path. This can be really slow. -func NameForPackage(importPath string) string { - if importPath == "" { - panic(errors.New("import path can not be empty")) - } - if nameForPackageCache == nil { - panic(fmt.Errorf("NameForPackage called for %s before RecordPackagesList", importPath)) - } - nameForPackageCacheLock.Lock() - defer nameForPackageCacheLock.Unlock() - var p *packages.Package - for _, pkg := range nameForPackageCache { - if pkg.PkgPath == importPath { - p = pkg - break - } - } - - if p == nil || p.Name == "" { - return SanitizePackageName(filepath.Base(importPath)) - } - return p.Name -} diff --git a/internal/code/imports_test.go b/internal/code/imports_test.go index f7e6f24a5a..ec00982f3c 100644 --- a/internal/code/imports_test.go +++ b/internal/code/imports_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/tools/go/packages" ) func TestImportPathForDir(t *testing.T) { @@ -31,20 +30,6 @@ func TestImportPathForDir(t *testing.T) { } } -func TestNameForPackage(t *testing.T) { - testPkg1 := "github.com/99designs/gqlgen/api" - testPkg2 := "github.com/99designs/gqlgen/docs" - testPkg3 := "github.com" - ps, err := packages.Load(nil, testPkg1, testPkg2, testPkg3) - require.NoError(t, err) - RecordPackagesList(ps) - assert.Equal(t, "api", NameForPackage(testPkg1)) - - // does not contain go code, should still give a valid name - assert.Equal(t, "docs", NameForPackage(testPkg2)) - assert.Equal(t, "github_com", NameForPackage(testPkg3)) -} - func TestNameForDir(t *testing.T) { wd, err := os.Getwd() require.NoError(t, err) diff --git a/internal/code/packages.go b/internal/code/packages.go new file mode 100644 index 0000000000..b14c45ad27 --- /dev/null +++ b/internal/code/packages.go @@ -0,0 +1,173 @@ +package code + +import ( + "bytes" + "path/filepath" + + "github.com/pkg/errors" + "golang.org/x/tools/go/packages" +) + +var mode = packages.NeedName | + packages.NeedFiles | + packages.NeedImports | + packages.NeedTypes | + packages.NeedSyntax | + packages.NeedTypesInfo + +// Packages is a wrapper around x/tools/go/packages that maintains a (hopefully prewarmed) cache of packages +// that can be invalidated as writes are made and packages are known to change. +type Packages struct { + packages map[string]*packages.Package + importToName map[string]string + loadErrors []error + + numLoadCalls int // stupid test steam. ignore. + numNameCalls int // stupid test steam. ignore. +} + +// LoadAll will call packages.Load and return the package data for the given packages, +// but if the package already have been loaded it will return cached values instead. +func (p *Packages) LoadAll(importPaths ...string) []*packages.Package { + if p.packages == nil { + p.packages = map[string]*packages.Package{} + } + + missing := make([]string, 0, len(importPaths)) + for _, path := range importPaths { + if _, ok := p.packages[path]; ok { + continue + } + missing = append(missing, path) + } + + if len(missing) > 0 { + p.numLoadCalls++ + pkgs, err := packages.Load(&packages.Config{Mode: mode}, missing...) + if err != nil { + p.loadErrors = append(p.loadErrors, err) + } + + for _, pkg := range pkgs { + p.addToCache(pkg) + } + } + + res := make([]*packages.Package, 0, len(importPaths)) + for _, path := range importPaths { + res = append(res, p.packages[NormalizeVendor(path)]) + } + return res +} + +func (p *Packages) addToCache(pkg *packages.Package) { + imp := NormalizeVendor(pkg.PkgPath) + p.packages[imp] = pkg + for _, imp := range pkg.Imports { + if _, found := p.packages[NormalizeVendor(imp.PkgPath)]; !found { + p.addToCache(imp) + } + } +} + +// Load works the same as LoadAll, except a single package at a time. +func (p *Packages) Load(importPath string) *packages.Package { + pkgs := p.LoadAll(importPath) + if len(pkgs) == 0 { + return nil + } + return pkgs[0] +} + +// LoadWithTypes tries a standard load, which may not have enough type info (TypesInfo== nil) available if the imported package is a +// second order dependency. Fortunately this doesnt happen very often, so we can just issue a load when we detect it. +func (p *Packages) LoadWithTypes(importPath string) *packages.Package { + pkg := p.Load(importPath) + if pkg == nil || pkg.TypesInfo == nil { + p.numLoadCalls++ + pkgs, err := packages.Load(&packages.Config{Mode: mode}, importPath) + if err != nil { + p.loadErrors = append(p.loadErrors, err) + return nil + } + p.addToCache(pkgs[0]) + pkg = pkgs[0] + } + return pkg +} + +// NameForPackage looks up the package name from the package stanza in the go files at the given import path. +func (p *Packages) NameForPackage(importPath string) string { + if importPath == "" { + panic(errors.New("import path can not be empty")) + } + if p.importToName == nil { + p.importToName = map[string]string{} + } + + importPath = NormalizeVendor(importPath) + + // if its in the name cache use it + if name := p.importToName[importPath]; name != "" { + return name + } + + // otherwise we might have already loaded the full package data for it cached + pkg := p.packages[importPath] + + if pkg == nil { + // otherwise do a name only lookup for it but dont put it in the package cache. + p.numNameCalls++ + pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, importPath) + if err != nil { + p.loadErrors = append(p.loadErrors, err) + } else { + pkg = pkgs[0] + } + } + + if pkg == nil || pkg.Name == "" { + return SanitizePackageName(filepath.Base(importPath)) + } + + p.importToName[importPath] = pkg.Name + + return pkg.Name +} + +// Evict removes a given package import path from the cache, along with any packages that depend on it. Further calls +// to Load will fetch it from disk. +func (p *Packages) Evict(importPath string) { + delete(p.packages, importPath) + + for _, pkg := range p.packages { + for _, imported := range pkg.Imports { + if imported.PkgPath == importPath { + p.Evict(pkg.PkgPath) + } + } + } +} + +// Errors returns any errors that were returned by Load, either from the call itself or any of the loaded packages. +func (p *Packages) Errors() PkgErrors { + var res []error //nolint:prealloc + res = append(res, p.loadErrors...) + for _, pkg := range p.packages { + for _, err := range pkg.Errors { + res = append(res, err) + } + } + return res +} + +type PkgErrors []error + +func (p PkgErrors) Error() string { + var b bytes.Buffer + b.WriteString("packages.Load: ") + for _, e := range p { + b.WriteString(e.Error() + "\n") + } + return b.String() +} diff --git a/internal/code/packages_test.go b/internal/code/packages_test.go new file mode 100644 index 0000000000..2fbf780c55 --- /dev/null +++ b/internal/code/packages_test.go @@ -0,0 +1,66 @@ +package code + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPackages(t *testing.T) { + t.Run("name for existing package does not load again", func(t *testing.T) { + p := initialState(t) + require.Equal(t, "a", p.NameForPackage("github.com/99designs/gqlgen/internal/code/testdata/a")) + require.Equal(t, 1, p.numLoadCalls) + }) + + t.Run("name for unknown package makes name only load", func(t *testing.T) { + p := initialState(t) + require.Equal(t, "c", p.NameForPackage("github.com/99designs/gqlgen/internal/code/testdata/c")) + require.Equal(t, 1, p.numLoadCalls) + require.Equal(t, 1, p.numNameCalls) + }) + + t.Run("evicting a package causes it to load again", func(t *testing.T) { + p := initialState(t) + p.Evict("github.com/99designs/gqlgen/internal/code/testdata/b") + require.Equal(t, "a", p.Load("github.com/99designs/gqlgen/internal/code/testdata/a").Name) + require.Equal(t, 1, p.numLoadCalls) + require.Equal(t, "b", p.Load("github.com/99designs/gqlgen/internal/code/testdata/b").Name) + require.Equal(t, 2, p.numLoadCalls) + }) + + t.Run("evicting a package also evicts its dependencies", func(t *testing.T) { + p := initialState(t) + p.Evict("github.com/99designs/gqlgen/internal/code/testdata/a") + require.Equal(t, "a", p.Load("github.com/99designs/gqlgen/internal/code/testdata/a").Name) + require.Equal(t, 2, p.numLoadCalls) + require.Equal(t, "b", p.Load("github.com/99designs/gqlgen/internal/code/testdata/b").Name) + require.Equal(t, 3, p.numLoadCalls) + }) +} + +func TestNameForPackage(t *testing.T) { + var p Packages + + assert.Equal(t, "api", p.NameForPackage("github.com/99designs/gqlgen/api")) + + // does not contain go code, should still give a valid name + assert.Equal(t, "docs", p.NameForPackage("github.com/99designs/gqlgen/docs")) + assert.Equal(t, "github_com", p.NameForPackage("github.com")) +} + +func initialState(t *testing.T) *Packages { + p := &Packages{} + pkgs := p.LoadAll( + "github.com/99designs/gqlgen/internal/code/testdata/a", + "github.com/99designs/gqlgen/internal/code/testdata/b", + ) + require.Nil(t, p.Errors()) + + require.Equal(t, 1, p.numLoadCalls) + require.Equal(t, 0, p.numNameCalls) + require.Equal(t, "a", pkgs[0].Name) + require.Equal(t, "b", pkgs[1].Name) + return p +} diff --git a/internal/code/testdata/a/a.go b/internal/code/testdata/a/a.go new file mode 100644 index 0000000000..bc4bece9a7 --- /dev/null +++ b/internal/code/testdata/a/a.go @@ -0,0 +1,3 @@ +package a + +var A = "A" diff --git a/internal/code/testdata/b/b.go b/internal/code/testdata/b/b.go new file mode 100644 index 0000000000..ecf84ccf62 --- /dev/null +++ b/internal/code/testdata/b/b.go @@ -0,0 +1,5 @@ +package b + +import "github.com/99designs/gqlgen/internal/code/testdata/a" + +var B = a.A + " B" diff --git a/internal/code/testdata/c/c.go b/internal/code/testdata/c/c.go new file mode 100644 index 0000000000..a6c7018848 --- /dev/null +++ b/internal/code/testdata/c/c.go @@ -0,0 +1,7 @@ +package c + +import ( + "github.com/99designs/gqlgen/internal/code/testdata/b" +) + +var C = b.B + " C" diff --git a/internal/imports/prune.go b/internal/imports/prune.go index 27ac94ac0f..d42a415791 100644 --- a/internal/imports/prune.go +++ b/internal/imports/prune.go @@ -24,7 +24,7 @@ func (fn visitFn) Visit(node ast.Node) ast.Visitor { } // Prune removes any unused imports -func Prune(filename string, src []byte) ([]byte, error) { +func Prune(filename string, src []byte, packages *code.Packages) ([]byte, error) { fset := token.NewFileSet() file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.AllErrors) @@ -32,7 +32,7 @@ func Prune(filename string, src []byte) ([]byte, error) { return nil, err } - unused := getUnusedImports(file) + unused := getUnusedImports(file, packages) for ipath, name := range unused { astutil.DeleteNamedImport(fset, file, name, ipath) } @@ -46,7 +46,7 @@ func Prune(filename string, src []byte) ([]byte, error) { return imports.Process(filename, buf.Bytes(), &imports.Options{FormatOnly: true, Comments: true, TabIndent: true, TabWidth: 8}) } -func getUnusedImports(file ast.Node) map[string]string { +func getUnusedImports(file ast.Node, packages *code.Packages) map[string]string { imported := map[string]*ast.ImportSpec{} used := map[string]bool{} @@ -65,7 +65,7 @@ func getUnusedImports(file ast.Node) map[string]string { break } - local := code.NameForPackage(ipath) + local := packages.NameForPackage(ipath) imported[local] = v case *ast.SelectorExpr: diff --git a/internal/imports/prune_test.go b/internal/imports/prune_test.go index 5f1563e510..a50220d757 100644 --- a/internal/imports/prune_test.go +++ b/internal/imports/prune_test.go @@ -6,14 +6,12 @@ import ( "github.com/99designs/gqlgen/internal/code" "github.com/stretchr/testify/require" - "golang.org/x/tools/go/packages" ) func TestPrune(t *testing.T) { // prime the packages cache so that it's not considered uninitialized - code.RecordPackagesList([]*packages.Package{}) - b, err := Prune("testdata/unused.go", mustReadFile("testdata/unused.go")) + b, err := Prune("testdata/unused.go", mustReadFile("testdata/unused.go"), &code.Packages{}) require.NoError(t, err) require.Equal(t, string(mustReadFile("testdata/unused.expected.go")), string(b)) } diff --git a/plugin/federation/federation.go b/plugin/federation/federation.go index f7142111ee..37994f5b35 100644 --- a/plugin/federation/federation.go +++ b/plugin/federation/federation.go @@ -173,7 +173,7 @@ func (f *federation) MutateSchema(s *ast.Schema) error { func (f *federation) getSource(builtin bool) *ast.Source { return &ast.Source{ Name: "federation.graphql", - Input: `# Declarations as required by the federation spec + Input: `# Declarations as required by the federation spec # See: https://www.apollographql.com/docs/apollo-server/federation/federation-spec/ scalar _Any @@ -244,15 +244,18 @@ func (f *federation) GenerateCode(data *codegen.Data) error { Filename: "service.go", Data: f, GeneratedHeader: true, + Packages: data.Config.Packages, }) } func (f *federation) setEntities(cfg *config.Config) { - schema, err := cfg.LoadSchema() + // crazy hack to get our injected code in so everything compiles, so we can generate the entity map + // so we can reload the full schema. + err := cfg.LoadSchema() if err != nil { panic(err) } - for _, schemaType := range schema.Types { + for _, schemaType := range cfg.Schema.Types { if schemaType.Kind == ast.Object { dir := schemaType.Directives.ForName("key") // TODO: interfaces if dir != nil { diff --git a/plugin/federation/federation_test.go b/plugin/federation/federation_test.go index b9876fdcfc..c6b0b37bab 100644 --- a/plugin/federation/federation_test.go +++ b/plugin/federation/federation_test.go @@ -47,11 +47,15 @@ func TestGetSDL(t *testing.T) { func TestMutateConfig(t *testing.T) { cfg, err := config.LoadConfig("test_data/gqlgen.yml") require.NoError(t, err) - require.NoError(t, cfg.Check()) f := &federation{} - err = f.MutateConfig(cfg) - require.NoError(t, err) + f.InjectSources(cfg) + + require.NoError(t, cfg.LoadSchema()) + require.NoError(t, f.MutateSchema(cfg.Schema)) + require.NoError(t, cfg.Init()) + require.NoError(t, f.MutateConfig(cfg)) + } func TestInjectSourcesNoKey(t *testing.T) { @@ -75,7 +79,7 @@ func TestGetSDLNoKey(t *testing.T) { func TestMutateConfigNoKey(t *testing.T) { cfg, err := config.LoadConfig("test_data/nokey.yml") require.NoError(t, err) - require.NoError(t, cfg.Check()) + require.NoError(t, cfg.Init()) f := &federation{} err = f.MutateConfig(cfg) diff --git a/plugin/modelgen/models.go b/plugin/modelgen/models.go index 3e689c02c8..1395901c50 100644 --- a/plugin/modelgen/models.go +++ b/plugin/modelgen/models.go @@ -7,11 +7,8 @@ import ( "github.com/99designs/gqlgen/codegen/config" "github.com/99designs/gqlgen/codegen/templates" - "github.com/99designs/gqlgen/internal/code" "github.com/99designs/gqlgen/plugin" - "github.com/pkg/errors" "github.com/vektah/gqlparser/ast" - "golang.org/x/tools/go/packages" ) type BuildMutateHook = func(b *ModelBuild) *ModelBuild @@ -75,29 +72,14 @@ func (m *Plugin) Name() string { } func (m *Plugin) MutateConfig(cfg *config.Config) error { - schema, err := cfg.LoadSchema() - if err != nil { - return err - } - - err = cfg.Autobind(schema) - if err != nil { - return err - } - - cfg.InjectBuiltins(schema) - - binder, err := cfg.NewBinder(schema) - if err != nil { - return err - } + binder := cfg.NewBinder() b := &ModelBuild{ PackageName: cfg.Model.Package, } var hasEntity bool - for _, schemaType := range schema.Types { + for _, schemaType := range cfg.Schema.Types { if cfg.Models.UserDefined(schemaType.Name) { continue } @@ -117,14 +99,14 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error { b.Interfaces = append(b.Interfaces, it) case ast.Object, ast.InputObject: - if schemaType == schema.Query || schemaType == schema.Mutation || schemaType == schema.Subscription { + if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription { continue } it := &Object{ Description: schemaType.Description, Name: schemaType.Name, } - for _, implementor := range schema.GetImplements(schemaType) { + for _, implementor := range cfg.Schema.GetImplements(schemaType) { it.Implements = append(it.Implements, implementor.Name) } @@ -133,9 +115,10 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error { } for _, field := range schemaType.Fields { var typ types.Type - fieldDef := schema.Types[field.Type.Name()] + fieldDef := cfg.Schema.Types[field.Type.Name()] if cfg.Models.UserDefined(field.Type.Name()) { + var err error typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0]) if err != nil { return err @@ -249,17 +232,12 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error { b = m.MutateHook(b) } - pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, cfg.Models.ReferencedPackages()...) - if err != nil { - return errors.Wrap(err, "loading failed") - } - code.RecordPackagesList(pkgs) - return templates.Render(templates.Options{ PackageName: cfg.Model.Package, Filename: cfg.Model.Filename, Data: b, GeneratedHeader: true, + Packages: cfg.Packages, }) } diff --git a/plugin/modelgen/models_test.go b/plugin/modelgen/models_test.go index 998acc2fcd..64cbff9e20 100644 --- a/plugin/modelgen/models_test.go +++ b/plugin/modelgen/models_test.go @@ -14,7 +14,7 @@ import ( func TestModelGeneration(t *testing.T) { cfg, err := config.LoadConfig("testdata/gqlgen.yml") require.NoError(t, err) - require.NoError(t, cfg.Check()) + require.NoError(t, cfg.Init()) p := Plugin{ MutateHook: mutateHook, } diff --git a/plugin/plugin.go b/plugin/plugin.go index 42d1adc5d6..3a18745370 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -5,6 +5,7 @@ package plugin import ( "github.com/99designs/gqlgen/codegen" "github.com/99designs/gqlgen/codegen/config" + "github.com/vektah/gqlparser/ast" ) type Plugin interface { @@ -22,3 +23,7 @@ type CodeGenerator interface { type SourcesInjector interface { InjectSources(cfg *config.Config) } + +type SchemaMutator interface { + MutateSchema(s *ast.Schema) error +} diff --git a/plugin/resolvergen/resolver.go b/plugin/resolvergen/resolver.go index 5fda546cb4..7276b7d636 100644 --- a/plugin/resolvergen/resolver.go +++ b/plugin/resolvergen/resolver.go @@ -74,6 +74,7 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error { PackageDoc: `// THIS CODE IS A STARTING POINT ONLY. IT WILL NOT BE UPDATED WITH SCHEMA CHANGES.`, Filename: data.Config.Resolver.Filename, Data: resolverBuild, + Packages: data.Config.Packages, }) } @@ -136,6 +137,7 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error { // will be copied through when generating and any unknown code will be moved to the end.`, Filename: filename, Data: resolverBuild, + Packages: data.Config.Packages, }) if err != nil { return err @@ -152,6 +154,7 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error { Template: `type {{.}} struct {}`, Filename: data.Config.Resolver.Filename, Data: data.Config.Resolver.Type, + Packages: data.Config.Packages, }) if err != nil { return err diff --git a/plugin/resolvergen/resolver_test.go b/plugin/resolvergen/resolver_test.go index 07cc599538..a994591193 100644 --- a/plugin/resolvergen/resolver_test.go +++ b/plugin/resolvergen/resolver_test.go @@ -19,7 +19,9 @@ func TestLayoutSingleFile(t *testing.T) { require.NoError(t, err) p := Plugin{} - data, err := codegen.BuildData(cfg, nil) + require.NoError(t, cfg.Init()) + + data, err := codegen.BuildData(cfg) if err != nil { panic(err) } @@ -35,7 +37,9 @@ func TestLayoutFollowSchema(t *testing.T) { require.NoError(t, err) p := Plugin{} - data, err := codegen.BuildData(cfg, nil) + require.NoError(t, cfg.Init()) + + data, err := codegen.BuildData(cfg) if err != nil { panic(err) } diff --git a/plugin/schemaconfig/schemaconfig.go b/plugin/schemaconfig/schemaconfig.go deleted file mode 100644 index 4fea2ebaeb..0000000000 --- a/plugin/schemaconfig/schemaconfig.go +++ /dev/null @@ -1,89 +0,0 @@ -package schemaconfig - -import ( - "github.com/99designs/gqlgen/codegen/config" - "github.com/99designs/gqlgen/plugin" - "github.com/vektah/gqlparser/ast" -) - -func New() plugin.Plugin { - return &Plugin{} -} - -type Plugin struct{} - -var _ plugin.ConfigMutator = &Plugin{} - -func (m *Plugin) Name() string { - return "schemaconfig" -} - -func (m *Plugin) MutateConfig(cfg *config.Config) error { - schema, err := cfg.LoadSchema() - if err != nil { - return err - } - - cfg.Directives["goModel"] = config.DirectiveConfig{ - SkipRuntime: true, - } - - cfg.Directives["goField"] = config.DirectiveConfig{ - SkipRuntime: true, - } - - for _, schemaType := range schema.Types { - if schemaType == schema.Query || schemaType == schema.Mutation || schemaType == schema.Subscription { - continue - } - - if bd := schemaType.Directives.ForName("goModel"); bd != nil { - if ma := bd.Arguments.ForName("model"); ma != nil { - if mv, err := ma.Value.Value(nil); err == nil { - cfg.Models.Add(schemaType.Name, mv.(string)) - } - } - if ma := bd.Arguments.ForName("models"); ma != nil { - if mvs, err := ma.Value.Value(nil); err == nil { - for _, mv := range mvs.([]interface{}) { - cfg.Models.Add(schemaType.Name, mv.(string)) - } - } - } - } - - if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject { - for _, field := range schemaType.Fields { - if fd := field.Directives.ForName("goField"); fd != nil { - forceResolver := cfg.Models[schemaType.Name].Fields[field.Name].Resolver - fieldName := cfg.Models[schemaType.Name].Fields[field.Name].FieldName - - if ra := fd.Arguments.ForName("forceResolver"); ra != nil { - if fr, err := ra.Value.Value(nil); err == nil { - forceResolver = fr.(bool) - } - } - - if na := fd.Arguments.ForName("name"); na != nil { - if fr, err := na.Value.Value(nil); err == nil { - fieldName = fr.(string) - } - } - - if cfg.Models[schemaType.Name].Fields == nil { - cfg.Models[schemaType.Name] = config.TypeMapEntry{ - Model: cfg.Models[schemaType.Name].Model, - Fields: map[string]config.TypeMapField{}, - } - } - - cfg.Models[schemaType.Name].Fields[field.Name] = config.TypeMapField{ - FieldName: fieldName, - Resolver: forceResolver, - } - } - } - } - } - return nil -} diff --git a/plugin/servergen/server.go b/plugin/servergen/server.go index 22289c0254..029c9ae398 100644 --- a/plugin/servergen/server.go +++ b/plugin/servergen/server.go @@ -34,6 +34,7 @@ func (m *Plugin) GenerateCode(data *codegen.Data) error { PackageName: "main", Filename: m.filename, Data: serverBuild, + Packages: data.Config.Packages, }) } diff --git a/plugin/stubgen/stubs.go b/plugin/stubgen/stubs.go index af5171b4cf..6540d34b07 100644 --- a/plugin/stubgen/stubs.go +++ b/plugin/stubgen/stubs.go @@ -48,6 +48,7 @@ func (m *Plugin) GenerateCode(data *codegen.Data) error { TypeName: m.typeName, }, GeneratedHeader: true, + Packages: data.Config.Packages, }) }