Skip to content

Commit

Permalink
Go modules support, migrate from golang.org/x/tools loader to packages
Browse files Browse the repository at this point in the history
This commit adds Go modules support by switching to the new package that
supports them in the golang tools repository.

Closes BurntSushi/go-sumtype#8
  • Loading branch information
kujenga authored and BurntSushi committed Feb 25, 2019
1 parent 0c33c2c commit e93e76e
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 61 deletions.
22 changes: 10 additions & 12 deletions check.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"sort"
"strings"

"golang.org/x/tools/go/loader"
"golang.org/x/tools/go/packages"
)

// inexhaustiveError is returned from check for each occurrence of inexhaustive
Expand Down Expand Up @@ -38,15 +38,15 @@ func (e inexhaustiveError) Names() []string {

// check does exhaustiveness checking for the given sum type definitions in the
// given package. Every instance of inexhaustive case analysis is returned.
func check(prog *loader.Program, defs []sumTypeDef, pkg *loader.PackageInfo) []error {
func check(pkg *packages.Package, defs []sumTypeDef) []error {
var errs []error
for _, astfile := range pkg.Files {
for _, astfile := range pkg.Syntax {
ast.Inspect(astfile, func(n ast.Node) bool {
swtch, ok := n.(*ast.TypeSwitchStmt)
if !ok {
return true
}
if err := checkSwitch(prog, pkg, defs, swtch); err != nil {
if err := checkSwitch(pkg, defs, swtch); err != nil {
errs = append(errs, err)
}
return true
Expand All @@ -63,15 +63,14 @@ func check(prog *loader.Program, defs []sumTypeDef, pkg *loader.PackageInfo) []e
// Note that if the type switch contains a non-panicing default case, then
// exhaustiveness checks are disabled.
func checkSwitch(
prog *loader.Program,
pkg *loader.PackageInfo,
pkg *packages.Package,
defs []sumTypeDef,
swtch *ast.TypeSwitchStmt,
) error {
def, missing := missingVariantsInSwitch(prog, pkg, defs, swtch)
def, missing := missingVariantsInSwitch(pkg, defs, swtch)
if len(missing) > 0 {
return inexhaustiveError{
Pos: prog.Fset.Position(swtch.Pos()),
Pos: pkg.Fset.Position(swtch.Pos()),
Def: *def,
Missing: missing,
}
Expand All @@ -84,13 +83,12 @@ func checkSwitch(
// returned. (If no sum type definition could be found, then no exhaustiveness
// checks are performed, and therefore, no missing variants are returned.)
func missingVariantsInSwitch(
prog *loader.Program,
pkg *loader.PackageInfo,
pkg *packages.Package,
defs []sumTypeDef,
swtch *ast.TypeSwitchStmt,
) (*sumTypeDef, []types.Object) {
asserted := findTypeAssertExpr(swtch)
ty := pkg.TypeOf(asserted)
ty := pkg.TypesInfo.TypeOf(asserted)
def := findDef(defs, ty)
if def == nil {
// We couldn't find a corresponding sum type, so there's
Expand All @@ -104,7 +102,7 @@ func missingVariantsInSwitch(
}
var variantTypes []types.Type
for _, expr := range variantExprs {
variantTypes = append(variantTypes, pkg.TypeOf(expr))
variantTypes = append(variantTypes, pkg.TypesInfo.TypeOf(expr))
}
return def, def.missing(variantTypes)
}
Expand Down
32 changes: 16 additions & 16 deletions check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ func main() {
}
}
`
tmpdir, prog := setupPackage(t, code)
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := run(prog)
errs := run(pkgs)
if !assert.Len(t, errs, 1) {
t.FailNow()
}
Expand Down Expand Up @@ -61,10 +61,10 @@ func main() {
}
}
`
tmpdir, prog := setupPackage(t, code)
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := run(prog)
errs := run(pkgs)
if !assert.Len(t, errs, 1) {
t.FailNow()
}
Expand Down Expand Up @@ -95,10 +95,10 @@ func main() {
}
}
`
tmpdir, prog := setupPackage(t, code)
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := run(prog)
errs := run(pkgs)
if !assert.Len(t, errs, 1) {
t.FailNow()
}
Expand Down Expand Up @@ -129,10 +129,10 @@ func main() {
}
}
`
tmpdir, prog := setupPackage(t, code)
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := run(prog)
errs := run(pkgs)
assert.Len(t, errs, 0)
}

Expand Down Expand Up @@ -160,10 +160,10 @@ func main() {
}
}
`
tmpdir, prog := setupPackage(t, code)
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := run(prog)
errs := run(pkgs)
assert.Len(t, errs, 0)
}

Expand All @@ -179,10 +179,10 @@ type T interface {}
func main() {}
`
tmpdir, prog := setupPackage(t, code)
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := run(prog)
errs := run(pkgs)
if !assert.Len(t, errs, 1) {
t.FailNow()
}
Expand All @@ -199,10 +199,10 @@ package main
func main() {}
`
tmpdir, prog := setupPackage(t, code)
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := run(prog)
errs := run(pkgs)
if !assert.Len(t, errs, 1) {
t.FailNow()
}
Expand All @@ -221,10 +221,10 @@ type T struct {}
func main() {}
`
tmpdir, prog := setupPackage(t, code)
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := run(prog)
errs := run(pkgs)
if !assert.Len(t, errs, 1) {
t.FailNow()
}
Expand Down
13 changes: 6 additions & 7 deletions decl.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (
"path/filepath"
"regexp"

"golang.org/x/tools/go/loader"
"golang.org/x/tools/go/packages"
)

// sumTypeDecl is a declaration of a sum type in a Go source file.
type sumTypeDecl struct {
// The package path that contains this decl.
PackageInfo *loader.PackageInfo
Package *packages.Package
// The type named by this decl.
TypeName string
// The file path where this declaration was found.
Expand All @@ -31,11 +31,10 @@ func (d sumTypeDecl) Location() string {

// findSumTypeDecls searches every package given for sum type declarations of
// the form `go-sumtype:decl ...`.
func findSumTypeDecls(prog *loader.Program) ([]sumTypeDecl, error) {
func findSumTypeDecls(pkgs []*packages.Package) ([]sumTypeDecl, error) {
var decls []sumTypeDecl
for _, pkginfo := range prog.InitialPackages() {
for _, astfile := range pkginfo.Files {
filename := prog.Fset.Position(astfile.Package).Filename
for _, pkg := range pkgs {
for _, filename := range pkg.CompiledGoFiles {
if filepath.Base(filename) == "C" {
// ignore (fake?) cgo files
continue
Expand All @@ -45,7 +44,7 @@ func findSumTypeDecls(prog *loader.Program) ([]sumTypeDecl, error) {
return nil, err
}
for i := range fileDecls {
fileDecls[i].PackageInfo = pkginfo
fileDecls[i].Package = pkg
}
decls = append(decls, fileDecls...)
}
Expand Down
6 changes: 2 additions & 4 deletions def.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"fmt"
"go/ast"
"go/types"

"golang.org/x/tools/go/loader"
)

// unsealedError corresponds to a declared sum type whose interface is not
Expand Down Expand Up @@ -53,11 +51,11 @@ type sumTypeDef struct {
// findSumTypeDefs attempts to find a Go type definition for each of the given
// sum type declarations. If no such sum type definition could be found for
// any of the given declarations, then an error is returned.
func findSumTypeDefs(prog *loader.Program, decls []sumTypeDecl) ([]sumTypeDef, []error) {
func findSumTypeDefs(decls []sumTypeDecl) ([]sumTypeDef, []error) {
var defs []sumTypeDef
var errs []error
for _, decl := range decls {
def, err := newSumTypeDef(decl.PackageInfo.Pkg, decl)
def, err := newSumTypeDef(decl.Package.Types, decl)
if err != nil {
errs = append(errs, err)
continue
Expand Down
8 changes: 4 additions & 4 deletions help_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"path/filepath"
"testing"

"golang.org/x/tools/go/loader"
"golang.org/x/tools/go/packages"
)

func setupPackage(t *testing.T, code string) (string, *loader.Program) {
func setupPackages(t *testing.T, code string) (string, []*packages.Package) {
tmpdir, err := ioutil.TempDir("", "go-test-sumtype-")
if err != nil {
t.Fatal(err)
Expand All @@ -18,11 +18,11 @@ func setupPackage(t *testing.T, code string) (string, *loader.Program) {
if err := ioutil.WriteFile(srcPath, []byte(code), 0666); err != nil {
t.Fatal(err)
}
prog, err := tycheckAll([]string{srcPath})
pkgs, err := tycheckAll([]string{srcPath})
if err != nil {
t.Fatal(err)
}
return tmpdir, prog
return tmpdir, pkgs
}

func teardownPackage(t *testing.T, dir string) {
Expand Down
33 changes: 15 additions & 18 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
package main

import (
"go/ast"
"log"
"os"
"strings"

"golang.org/x/tools/go/loader"
"golang.org/x/tools/go/packages"
)

func main() {
log.SetFlags(0)
if len(os.Args) < 2 {
// TODO: Switch this to use golang.org/x/tools/go/packages.
log.Fatalf("Usage: go-sumtype <args>\n%s", loader.FromArgsUsage)
}
pkgpaths := os.Args[1:]
prog, err := tycheckAll(pkgpaths)
args := os.Args[1:]
pkgs, err := tycheckAll(args)
if err != nil {
log.Fatal(err)
}
if errs := run(prog); len(errs) > 0 {
if errs := run(pkgs); len(errs) > 0 {
var list []string
for _, err := range errs {
list = append(list, err.Error())
Expand All @@ -28,39 +29,35 @@ func main() {
}
}

func run(prog *loader.Program) []error {
func run(pkgs []*packages.Package) []error {
var errs []error

decls, err := findSumTypeDecls(prog)
decls, err := findSumTypeDecls(pkgs)
if err != nil {
return []error{err}
}

defs, defErrs := findSumTypeDefs(prog, decls)
defs, defErrs := findSumTypeDefs(decls)
errs = append(errs, defErrs...)
if len(defs) == 0 {
return errs
}

for _, pkg := range prog.InitialPackages() {
if pkgErrs := check(prog, defs, pkg); pkgErrs != nil {
for _, pkg := range pkgs {
if pkgErrs := check(pkg, defs); pkgErrs != nil {
errs = append(errs, pkgErrs...)
}
}
return errs
}

func tycheckAll(pkgpaths []string) (*loader.Program, error) {
conf := &loader.Config{
AfterTypeCheck: func(info *loader.PackageInfo, files []*ast.File) {
},
func tycheckAll(args []string) ([]*packages.Package, error) {
conf := &packages.Config{
Mode: packages.LoadSyntax,
}
if _, err := conf.FromArgs(pkgpaths, true); err != nil {
return nil, err
}
prog, err := conf.Load()
pkgs, err := packages.Load(conf, args...)
if err != nil {
return nil, err
}
return prog, nil
return pkgs, nil
}

0 comments on commit e93e76e

Please sign in to comment.