/
import_bag.go
234 lines (197 loc) · 5.99 KB
/
import_bag.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
package common
import (
"bytes"
"errors"
"fmt"
"go/ast"
"go/importer"
"go/parser"
"go/token"
"go/types"
"path"
"sort"
"strings"
)
type (
	// PackageSpecifier is the name under which a source file refers to the
	// exported functions and variables of an imported package. Besides
	// ordinary identifiers, ImportBag also accepts the blank identifier "_"
	// and the dot specifier ".".
	PackageSpecifier string

	// PackagePath is the import path that locates a Go package.
	PackagePath string
)
// ImportBag records the imports of a Go source file together with the
// specifier each package is imported under.
//
// Create one with NewImportBag or NewImportBagFromFile: the maps below are
// written to when imports are added, so they must be non-nil first.
//
// bySpec maps each named specifier to the package path it refers to.
// blankIdent and localIdent are sets of package paths imported with the
// blank identifier "_" and with the dot specifier ".", respectively.
type ImportBag struct {
	bySpec                 map[PackageSpecifier]PackagePath
	blankIdent, localIdent map[PackagePath]struct{}
}
// NewImportBag returns an ImportBag with no imports whose internal maps are
// allocated and ready for use.
func NewImportBag() *ImportBag {
	ib := new(ImportBag)
	ib.bySpec = map[PackageSpecifier]PackagePath{}
	ib.blankIdent = map[PackagePath]struct{}{}
	ib.localIdent = map[PackagePath]struct{}{}
	return ib
}
// NewImportBagFromFile parses the import declarations of the Go source file
// at filepath and returns an ImportBag containing all of them.
//
// Unnamed imports are added with AddImport. Named imports, blank ("_")
// imports and dot (".") imports keep the specifier written in the file.
//
// If the file cannot be parsed, the parser's error is returned unchanged.
// A wrapped error is returned if an import path literal cannot be unquoted.
// A wrapped error is also returned if a named import cannot be added to the
// bag (ErrDuplicateImport or ErrMultipleLocalImport; test with errors.Is).
func NewImportBagFromFile(filepath string) (*ImportBag, error) {
	f, err := parser.ParseFile(token.NewFileSet(), filepath, nil, parser.ImportsOnly)
	if err != nil {
		return nil, err
	}
	ib := NewImportBag()
	for _, spec := range f.Imports {
		// An import path may be written as an interpreted ("...") or a raw
		// (`...`) string literal. Scanning with %q unquotes either form,
		// including escape sequences, whereas trimming '"' would leave raw
		// literals wrapped in backquotes.
		var unquoted string
		if _, err := fmt.Sscanf(spec.Path.Value, "%q", &unquoted); err != nil {
			return nil, fmt.Errorf("unquoting import path %s: %w", spec.Path.Value, err)
		}
		pkgPath := PackagePath(unquoted)
		if spec.Name == nil {
			ib.AddImport(pkgPath)
			continue
		}
		name := PackageSpecifier(spec.Name.Name)
		// Previously this error was discarded, silently dropping the import.
		if err := ib.AddImportWithSpecifier(pkgPath, name); err != nil {
			return nil, fmt.Errorf("importing %q as %q: %w", pkgPath, name, err)
		}
	}
	return ib, nil
}
// AddImport includes a package and returns the specifier chosen for
// referring to it.
//
// The first candidate specifier is the package's declared name, as reported
// by FindSpecifier. If the package cannot be loaded, the last element of
// pkgPath is used instead. While a candidate is already bound to a different
// package, a numeric suffix (1, 2, 3, ...) is appended to the base name and
// the next candidate is tried.
//
// If a candidate is already bound to pkgPath itself, that candidate is
// returned and the bag is left unchanged. A binding of pkgPath under an
// unrelated name (for example, one added with AddImportWithSpecifier) is not
// reused; an additional specifier is added instead.
//
// AddImport panics if a candidate is rejected with any error other than
// ErrDuplicateImport. That can happen only when the derived name is "." and
// pkgPath has already been dot-imported.
func (ib *ImportBag) AddImport(pkgPath PackagePath) PackageSpecifier {
	base, err := FindSpecifier(pkgPath)
	if err != nil {
		// The package could not be loaded; guess its name from the path.
		base = PackageSpecifier(path.Base(string(pkgPath)))
	}
	spec := base
	for suffix := 1; ; suffix++ {
		addErr := ib.AddImportWithSpecifier(pkgPath, spec)
		if addErr == nil {
			return spec
		}
		if !errors.Is(addErr, ErrDuplicateImport) {
			panic(addErr)
		}
		spec = PackageSpecifier(fmt.Sprintf("%s%d", base, suffix))
	}
}
var (
	// ErrDuplicateImport is returned by AddImportWithSpecifier when the
	// requested specifier is already bound to a different package path.
	ErrDuplicateImport = errors.New("specifier already in use in ImportBag")

	// ErrMultipleLocalImport is returned by AddImportWithSpecifier when a
	// package that is already imported with the specifier "." is requested
	// with "." again.
	ErrMultipleLocalImport = errors.New("package already imported into the local namespace")
)
// AddImportWithSpecifier adds pkgPath to the bag under the given specifier.
//
// The specifier "_" records a blank import and always succeeds. The
// specifier "." records a dot import and returns ErrMultipleLocalImport if
// pkgPath is already dot-imported. Any other specifier is bound to pkgPath.
// In that case ErrDuplicateImport is returned if the specifier is already
// bound to a different path; re-binding it to the same path is a no-op.
func (ib *ImportBag) AddImportWithSpecifier(pkgPath PackagePath, specifier PackageSpecifier) error {
	switch specifier {
	case "_":
		ib.blankIdent[pkgPath] = struct{}{}
	case ".":
		if _, already := ib.localIdent[pkgPath]; already {
			return ErrMultipleLocalImport
		}
		ib.localIdent[pkgPath] = struct{}{}
	default:
		if bound, taken := ib.bySpec[specifier]; taken && bound != pkgPath {
			return ErrDuplicateImport
		}
		ib.bySpec[specifier] = pkgPath
	}
	return nil
}
// FindSpecifier reports the specifier under which pkgPath was imported into
// the bag, and whether it was imported at all.
//
// A named specifier takes precedence over the blank identifier "_", which in
// turn takes precedence over the dot specifier ".". If pkgPath is bound to
// several named specifiers, which one is returned is unspecified, because it
// depends on map iteration order.
//
// If pkgPath is not in the bag, it returns "" and false.
func (ib ImportBag) FindSpecifier(pkgPath PackagePath) (PackageSpecifier, bool) {
	for spec, bound := range ib.bySpec {
		if bound != pkgPath {
			continue
		}
		return spec, true
	}
	inSet := func(set map[PackagePath]struct{}) bool {
		_, present := set[pkgPath]
		return present
	}
	switch {
	case inSet(ib.blankIdent):
		return "_", true
	case inSet(ib.localIdent):
		return ".", true
	default:
		return "", false
	}
}
// List renders every import in the bag as the text of an import spec. Each
// entry is the optional specifier and a space, followed by the quoted import
// path, for example `"fmt"` or `f "fmt"`. Entries follow the order produced
// by ListAsImportSpec, which sorts by import path.
func (ib *ImportBag) List() []string {
	specs := ib.ListAsImportSpec()
	lines := make([]string, 0, len(specs))
	var line bytes.Buffer
	for _, s := range specs {
		line.Reset()
		if s.Name != nil {
			line.WriteString(s.Name.Name + " ")
		}
		line.WriteString(s.Path.Value)
		lines = append(lines, line.String())
	}
	return lines
}
// ListAsImportSpec returns the imports in the bag as ast.ImportSpecs.
//
// Specs are sorted by their quoted import path. A package imported more than
// once (for example, both by name and with "_") appears once per specifier.
// Such entries are ordered by specifier, with an unnamed spec first, so the
// result is deterministic despite map iteration order.
//
// A named specifier gets an explicit Name only when it differs from the last
// element of its import path; otherwise Name is nil.
// NOTE(review): that rule assumes the package's declared name equals
// path.Base of its import path. For packages where it doesn't, an explicit
// specifier equal to the base would be dropped — confirm this is acceptable.
func (ib *ImportBag) ListAsImportSpec() []*ast.ImportSpec {
	specs := make([]*ast.ImportSpec, 0, len(ib.bySpec)+len(ib.localIdent)+len(ib.blankIdent))
	newSpec := func(name *ast.Ident, pkgPath PackagePath) *ast.ImportSpec {
		return &ast.ImportSpec{
			Name: name,
			Path: &ast.BasicLit{
				Kind:  token.STRING,
				Value: fmt.Sprintf("%q", string(pkgPath)),
			},
		}
	}
	for spec, pkgPath := range ib.bySpec {
		var name *ast.Ident
		if path.Base(string(pkgPath)) != string(spec) {
			name = ast.NewIdent(string(spec))
		}
		specs = append(specs, newSpec(name, pkgPath))
	}
	for pkgPath := range ib.localIdent {
		specs = append(specs, newSpec(ast.NewIdent("."), pkgPath))
	}
	for pkgPath := range ib.blankIdent {
		specs = append(specs, newSpec(ast.NewIdent("_"), pkgPath))
	}
	nameOf := func(s *ast.ImportSpec) string {
		if s.Name == nil {
			return ""
		}
		return s.Name.Name
	}
	// sort.Slice is not stable and the input order comes from map iteration,
	// so ties on the path must be broken explicitly to get a total order.
	sort.Slice(specs, func(i, j int) bool {
		if c := strings.Compare(specs[i].Path.Value, specs[j].Path.Value); c != 0 {
			return c < 0
		}
		return nameOf(specs[i]) < nameOf(specs[j])
	})
	return specs
}
// impFinder is the importer FindSpecifier uses to load packages. It is
// shared by all calls.
// NOTE(review): confirm the default importer is safe for concurrent use
// before calling FindSpecifier from multiple goroutines.
var impFinder = importer.Default()

// FindSpecifier loads pkgPath with the default importer and returns the name
// the package declares in its package clause. That name need not match the
// last element of the import path. Any error from the importer is returned
// unchanged together with an empty specifier.
//
// When the importer implements types.ImporterFrom, ImportFrom is used with an
// empty source directory and no mode flags; otherwise Import is used.
// NOTE(review): an earlier comment said packages are loaded from GOPATH or a
// vendor folder. Where the default importer actually looks (and whether the
// empty srcDir lets it find vendored packages) should be confirmed.
func FindSpecifier(pkgPath PackagePath) (PackageSpecifier, error) {
	load := impFinder.Import
	if from, ok := impFinder.(types.ImporterFrom); ok {
		load = func(p string) (*types.Package, error) {
			return from.ImportFrom(p, "", 0)
		}
	}
	pkg, err := load(string(pkgPath))
	if err != nil {
		return "", err
	}
	return PackageSpecifier(pkg.Name()), nil
}