forked from uber-go/nilaway
/
util.go
481 lines (433 loc) · 14.8 KB
/
util.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
// Copyright (c) 2023 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package util implements utility functions for AST and types.
package util
import (
"fmt"
"go/ast"
"go/token"
"go/types"
"regexp"
"strings"
"github.com/ZhongsJie/nilaway/config"
"golang.org/x/tools/go/analysis"
)
// ErrorType is the type of the predeclared "error" interface.
var ErrorType = types.Universe.Lookup("error").Type()

// BoolType is the type of the predeclared "bool" basic type (it is not an interface).
var BoolType = types.Universe.Lookup("bool").Type()

// BuiltinLen is the universe-scope object for the builtin "len" function.
var BuiltinLen = types.Universe.Lookup("len")
// TypeIsDeep reports whether t directly admits a deep nilability annotation, as
// decided by TypeAsDeepType. Deep nilability annotations on any other type are ignored.
func TypeIsDeep(t types.Type) bool {
	if _, ok := TypeAsDeepType(t); ok {
		return true
	}
	return false
}
// TypeAsDeepType reports whether t directly admits a deep nilability annotation.
// That is the case for slices, arrays, maps, channels and pointers, and in that
// case their element type is returned alongside true. Every other type (including
// named types, whose underlying type is NOT inspected) yields (nil, false).
// nilable(result 0)
func TypeAsDeepType(t types.Type) (types.Type, bool) {
	switch t.(type) {
	case *types.Slice, *types.Array, *types.Map, *types.Chan, *types.Pointer:
		// All five container kinds expose their element type through Elem().
		return t.(interface{ Elem() types.Type }).Elem(), true
	default:
		return nil, false
	}
}
// TypeIsSlice reports whether t is literally a slice type. Named types are not
// looked through; see TypeIsDeeplySlice for that.
func TypeIsSlice(t types.Type) bool {
	_, isSlice := t.(*types.Slice)
	return isSlice
}
// TypeIsDeeplyArray reports whether t is an array type, looking through Named
// types to their underlying type. At every step a single pointer level is also
// unwrapped (via UnwrapPtr), so e.g. *[3]int is reported as an array.
func TypeIsDeeplyArray(t types.Type) bool {
	for {
		switch cur := UnwrapPtr(t).(type) {
		case *types.Array:
			return true
		case *types.Named:
			t = cur.Underlying()
		default:
			return false
		}
	}
}
// TypeIsDeeplySlice reports whether t is a slice type, looking through Named
// types to their underlying type.
func TypeIsDeeplySlice(t types.Type) bool {
	named, isNamed := t.(*types.Named)
	if isNamed {
		return TypeIsDeeplySlice(named.Underlying())
	}
	return TypeIsSlice(t)
}
// TypeIsDeeplyMap reports whether t is a map type, looking through Named types
// to their underlying type.
func TypeIsDeeplyMap(t types.Type) bool {
	switch tt := t.(type) {
	case *types.Map:
		return true
	case *types.Named:
		return TypeIsDeeplyMap(tt.Underlying())
	default:
		return false
	}
}
// TypeIsDeeplyPtr reports whether t is a pointer type, looking through Named
// types to their underlying type.
func TypeIsDeeplyPtr(t types.Type) bool {
	switch tt := t.(type) {
	case *types.Pointer:
		return true
	case *types.Named:
		return TypeIsDeeplyPtr(tt.Underlying())
	default:
		return false
	}
}
// TypeIsDeeplyChan reports whether t is a channel type, looking through Named
// types to their underlying type.
func TypeIsDeeplyChan(t types.Type) bool {
	switch tt := t.(type) {
	case *types.Chan:
		return true
	case *types.Named:
		return TypeIsDeeplyChan(tt.Underlying())
	default:
		return false
	}
}
// TypeAsDeeplyStruct returns the struct type behind typ, or nil if there is none.
//
// typ may be a struct type, a Named type whose underlying type is a struct, or a
// pointer to either of those. A pointer to an unnamed struct (e.g. *struct{ f *int })
// is now handled as the doc always promised; it previously returned nil. Only one
// pointer level is looked through, and Named types whose underlying type is a pointer
// are not followed.
func TypeAsDeeplyStruct(typ types.Type) *types.Struct {
	if ptr, ok := typ.(*types.Pointer); ok {
		typ = ptr.Elem()
	}
	if named, ok := typ.(*types.Named); ok {
		typ = named.Underlying()
	}
	st, _ := typ.(*types.Struct)
	return st
}
// TypeIsDeeplyInterface reports whether t is an interface type, looking through
// Named types to their underlying type (so e.g. the predeclared `error` counts).
func TypeIsDeeplyInterface(t types.Type) bool {
	switch tt := t.(type) {
	case *types.Interface:
		return true
	case *types.Named:
		return TypeIsDeeplyInterface(tt.Underlying())
	default:
		return false
	}
}
// UnwrapPtr strips exactly one pointer level: for *T it returns T, and any
// non-pointer type is returned as is.
func UnwrapPtr(t types.Type) types.Type {
	ptr, isPtr := t.(*types.Pointer)
	if !isPtr {
		return t
	}
	return ptr.Elem()
}
// TypeOf returns the type recorded for expr in the pass's type information,
// or nil if none was recorded.
func TypeOf(pass *analysis.Pass, expr ast.Expr) types.Type {
	info := pass.TypesInfo
	return info.TypeOf(expr)
}
// FuncIdentFromCallExpr returns the identifier naming the called function: the
// identifier itself for `f(...)`, or the selected name for `x.f(...)`. Any other
// callee shape (for example a function literal invoked in place) yields nil.
// nilable(result 0)
func FuncIdentFromCallExpr(expr *ast.CallExpr) *ast.Ident {
	if ident, isIdent := expr.Fun.(*ast.Ident); isIdent {
		return ident
	}
	if sel, isSel := expr.Fun.(*ast.SelectorExpr); isSel {
		return sel.Sel
	}
	return nil
}
// PartiallyQualifiedFuncName returns the name of the passed function, with the name of its receiver
// if defined
func PartiallyQualifiedFuncName(f *types.Func) string {
if sig, ok := f.Type().(*types.Signature); ok && sig.Recv() != nil {
return fmt.Sprintf("%s.%s", PortionAfterSep(sig.Recv().Type().String(), ".", 0), f.Name())
}
return f.Name()
}
// PortionAfterSep returns the longest suffix of input that contains at most occ
// occurrences of sep. For example, PortionAfterSep("a/b/c/d", "/", 1) returns "c/d".
// If input contains at most occ occurrences of sep, it is returned unchanged.
// A negative occ yields the empty string.
func PortionAfterSep(input, sep string, occ int) string {
	if occ < 0 {
		return ""
	}
	splits := strings.Split(input, sep)
	n := len(splits)
	// n segments are separated by n-1 occurrences of sep. This is written as
	// occ >= n-1 rather than n <= occ+1 so that a very large occ cannot overflow.
	if occ >= n-1 {
		return input
	}
	// Rejoin the last occ+1 segments. strings.Join keeps the separator next to
	// empty segments (e.g. from "a//b"); the previous hand-rolled loop dropped
	// it, and it was also quadratic due to repeated string concatenation.
	return strings.Join(splits[n-1-occ:], sep)
}
// ExprIsAuthentic approximates whether expr is an AST node from the source program
// of this pass, rather than an intermediate node synthesized during analysis.
// There is no fully sound way to decide this; the approximation used is "the type
// checker recorded a type for it". It is currently used only to decide whether to
// print the location of the producer expression in a full trigger.
func ExprIsAuthentic(pass *analysis.Pass, expr ast.Expr) bool {
	if pass.TypesInfo.TypeOf(expr) == nil {
		return false
	}
	return true
}
// StripParens removes every layer of enclosing parentheses from expr and returns
// the innermost non-parenthesized node; any other node is returned unchanged.
func StripParens(expr ast.Node) ast.Node {
	for {
		paren, isParen := expr.(*ast.ParenExpr)
		if !isParen {
			return expr
		}
		expr = paren.X
	}
}
// IsSliceAppendCall checks whether node is a call to the builtin
// append(slice []Type, elems ...Type) []Type on a slice. On success it returns the
// slice type of the first argument and true; otherwise it returns nil and false.
//
// Two conditions are checked:
//  1. the callee is the builtin "append" (identified through the pass's Uses map);
//  2. the first argument's type is a *types.Slice. Named slice types are not
//     looked through.
func IsSliceAppendCall(node *ast.CallExpr, pass *analysis.Pass) (*types.Slice, bool) {
	funcName, ok := node.Fun.(*ast.Ident)
	// A call with no arguments (only possible in ill-typed code) must not reach the
	// node.Args[0] index below, which would panic.
	if !ok || len(node.Args) == 0 {
		return nil, false
	}
	declObj := pass.TypesInfo.Uses[funcName]
	if declObj == nil || declObj.String() != "builtin append" {
		return nil, false
	}
	sliceType, ok := pass.TypesInfo.TypeOf(node.Args[0]).(*types.Slice)
	if !ok {
		return nil, false
	}
	return sliceType, true
}
// TypeBarsNilness reports whether nil is excluded from type t, i.e. it returns
// false exactly when t is treated as inhabited by nil. Named types are judged by
// their underlying type.
//
// Slices, pointers, tuples, maps, channels, interfaces and the untyped-nil basic
// type admit nil. Every other type, including arrays, structs and function
// signatures, is reported as barring nil.
// NOTE(review): Go func values can in fact be nil; treating *types.Signature as
// nil-barring looks like a deliberate modelling choice (nil funcs not tracked) —
// confirm before relying on it.
func TypeBarsNilness(t types.Type) bool {
	switch tt := t.(type) {
	case *types.Named:
		return TypeBarsNilness(tt.Underlying())
	case *types.Basic:
		return tt.Kind() != types.UntypedNil
	case *types.Slice, *types.Pointer, *types.Tuple, *types.Map, *types.Chan, *types.Interface:
		return false
	default:
		// *types.Array, *types.Signature, *types.Struct and anything else.
		return true
	}
}
// ExprBarsNilness reports whether expr can never be nil simply because nil does not
// inhabit its type (see TypeBarsNilness).
func ExprBarsNilness(pass *analysis.Pass, expr ast.Expr) bool {
	// pass.TypesInfo.TypeOf consults only the Types, Uses and Defs maps, so some
	// expressions have no recorded type here. For example, `f` in `s.f` is only in
	// pass.TypesInfo.Selections (see the comments on pass.TypesInfo.Types). Be
	// conservative and report false for those cases for now.
	// TODO: to investigate and find more cases.
	if t := pass.TypesInfo.TypeOf(expr); t != nil {
		return TypeBarsNilness(t)
	}
	return false
}
// FuncNumResults returns how many results the function decl declares.
func FuncNumResults(decl *types.Func) int {
	sig := decl.Type().(*types.Signature)
	results := sig.Results()
	return results.Len()
}
// IsEmptyExpr reports whether expr is the blank identifier `_`.
func IsEmptyExpr(expr ast.Expr) bool {
	id, isIdent := expr.(*ast.Ident)
	return isIdent && id.Name == "_"
}
// funcIsRichCheckEffectReturning encodes the conditions under which a function is
// deemed "rich-check-effect-returning", i.e. error-returning or bool(ok)-returning.
// That holds iff fdecl's LAST result has type expectedType (e.g. error or bool) and
// no earlier result does.
//
// Types are compared with types.Identical rather than pointer equality, so a result
// declared through a type alias of expectedType (which may be represented as a
// *types.Alias) is still recognized.
func funcIsRichCheckEffectReturning(fdecl *types.Func, expectedType types.Type) bool {
	results := fdecl.Type().(*types.Signature).Results()
	n := results.Len()
	if n == 0 || !types.Identical(results.At(n-1).Type(), expectedType) {
		return false
	}
	for i := 0; i < n-1; i++ {
		if types.Identical(results.At(i).Type(), expectedType) {
			return false
		}
	}
	return true
}
// FuncIsErrReturning reports whether fdecl is "error-returning": its final result
// has type `error` and no earlier result does. Such a function's other results
// are guarded so that an `err` check is required before they are used as nonnil.
func FuncIsErrReturning(fdecl *types.Func) bool {
	isErrReturning := funcIsRichCheckEffectReturning(fdecl, ErrorType)
	return isErrReturning
}
// FuncIsOkReturning reports whether fdecl is "ok-returning": its final result has
// type `bool` and no earlier result does. Such a function's other results are
// guarded so that an `ok` check is required before they are used as nonnil.
func FuncIsOkReturning(fdecl *types.Func) bool {
	isOkReturning := funcIsRichCheckEffectReturning(fdecl, BoolType)
	return isOkReturning
}
// IsFieldSelectorChain reports whether expr is a chain made only of identifiers
// and selectors, such as `x` or `x.y.z`. Anything else in the chain (e.g. the
// call in `x.y().z`) makes it return false.
func IsFieldSelectorChain(expr ast.Expr) bool {
	for {
		switch e := expr.(type) {
		case *ast.Ident:
			return true
		case *ast.SelectorExpr:
			expr = e.X
		default:
			return false
		}
	}
}
// GetFieldVal returns the expression assigned to a struct field in a composite
// literal whose elements are compElts.
//
// Keyed elements (e.g. &A{p: x}) are searched by fieldName first. Otherwise, when
// the literal lists exactly numFields elements (positional form, e.g. &A{x, y}),
// the element at index is returned. If neither applies, the result is nil.
func GetFieldVal(compElts []ast.Expr, fieldName string, numFields int, index int) ast.Expr {
	for _, elt := range compElts {
		kv, isKV := elt.(*ast.KeyValueExpr)
		if !isKV {
			continue
		}
		if key, isIdent := kv.Key.(*ast.Ident); isIdent && key.Name == fieldName {
			return kv.Value
		}
	}
	if len(compElts) != numFields {
		return nil
	}
	return compElts[index]
}
// GetFunctionParamNode returns the parameter identifier in funcDecl's signature
// whose name equals searchParam's name, or nil if there is none. Unnamed ("") and
// blank ("_") parameter names never match.
func GetFunctionParamNode(funcDecl *ast.FuncDecl, searchParam *types.Var) ast.Expr {
	for _, field := range funcDecl.Type.Params.List {
		for _, ident := range field.Names {
			wanted := searchParam.Name()
			isNamed := ident.Name != "" && ident.Name != "_"
			if ident.Name == wanted && isNamed {
				return ident
			}
		}
	}
	return nil
}
// GetParamObjFromIndex returns the parameter variable of functionType that receives
// the call argument at position argIdx. Arguments past the declared parameters bind
// to the trailing variadic parameter. If the function is not variadic in that case,
// it panics, because the caller has broken the invariant.
func GetParamObjFromIndex(functionType *types.Func, argIdx int) *types.Var {
	sig := functionType.Type().(*types.Signature)
	params := sig.Params()
	if argIdx >= params.Len() {
		if !sig.Variadic() {
			panic("Function is expected to be variadic in the case when argument index >= length of params")
		}
		return params.At(params.Len() - 1)
	}
	return params.At(argIdx)
}
// GetSelectorExprHeadIdent walks down a chained selector expression (e.g. a.b.c)
// and returns its leftmost identifier (a). If the head is not an identifier
// (e.g. f().b), it returns nil.
func GetSelectorExprHeadIdent(selExpr *ast.SelectorExpr) *ast.Ident {
	cur := selExpr
	for {
		switch x := cur.X.(type) {
		case *ast.Ident:
			return x
		case *ast.SelectorExpr:
			cur = x
		default:
			return nil
		}
	}
}
// IsLiteral reports whether expr is an identifier whose name equals one of the
// given literals (e.g. "nil", "true", "false"). Only the identifier's name is
// compared; whether it actually resolves to the predeclared object is not checked.
func IsLiteral(expr ast.Expr, literals ...string) bool {
	ident, isIdent := expr.(*ast.Ident)
	if !isIdent {
		return false
	}
	for _, lit := range literals {
		if lit == ident.Name {
			return true
		}
	}
	return false
}
// TruncatePosition returns a copy of position whose Filename keeps only the
// trailing path components, with at most config.DirLevelsToPrintForTriggers "/"
// separators retained.
func TruncatePosition(position token.Position) token.Position {
	truncated := position
	truncated.Filename = PortionAfterSep(truncated.Filename, "/", config.DirLevelsToPrintForTriggers)
	return truncated
}
// codeReferencePattern matches a backtick-quoted span such as `x.f`, capturing the
// text between the backticks.
var codeReferencePattern = regexp.MustCompile("\\`(.*?)\\`")

// pathPattern matches a double-quoted span, capturing the text between the quotes.
var pathPattern = regexp.MustCompile(`"(.*?)"`)

// nilabilityPattern matches a nilability phrase ("found nilable", "must be nonnil",
// case-insensitive) that starts the message or directly follows "(" or a tab,
// optionally followed by ")". Group 1 captures the whole match.
//
// The prefix must be an alternation, (?:^|[(\t]). The previous form [\(|^\t] put
// "^" and "|" inside a character class, where they are literal characters, so a
// phrase at the very start of the message was never matched.
var nilabilityPattern = regexp.MustCompile(`((?:^|[(\t])(?i:found\s|must\sbe\s)(?i:nilable|nonnil)\)?)`)
// PrettyPrintErrorMessage is used in error reporting to post process and pretty print the output with colors
func PrettyPrintErrorMessage(msg string) string {
// TODO: below string parsing should not be required after is implemented
errorStr := fmt.Sprintf("\x1b[%dm%s\x1b[0m", 31, "error: ") // red
codeStr := fmt.Sprintf("\u001B[%dm%s\u001B[0m", 95, "`${1}`") // magenta
pathStr := fmt.Sprintf("\u001B[%dm%s\u001B[0m", 36, "${1}") // cyan
nilabilityStr := fmt.Sprintf("\u001B[%dm%s\u001B[0m", 1, "${1}") // bold
msg = nilabilityPattern.ReplaceAllString(msg, nilabilityStr)
msg = codeReferencePattern.ReplaceAllString(msg, codeStr)
msg = pathPattern.ReplaceAllString(msg, pathStr)
msg = errorStr + msg
return msg
}
// truncatePosition removes part of prefix of the full file path, determined by
// config.DirLevelsToPrintForTriggers.
func truncatePosition(position token.Position) token.Position {
position.Filename = PortionAfterSep(
position.Filename, "/",
config.DirLevelsToPrintForTriggers)
return position
}
// PosToLocation resolves pos against the pass's file set into a token.Position
// and truncates its file path for display (see truncatePosition).
func PosToLocation(pos token.Pos, pass *analysis.Pass) token.Position {
	position := pass.Fset.Position(pos)
	return truncatePosition(position)
}