-
Notifications
You must be signed in to change notification settings - Fork 194
/
statements.go
81 lines (71 loc) · 1.82 KB
/
statements.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/
package astbuilder
import (
"fmt"
"github.com/dave/dst"
)
// Statements creates a slice of dst.Stmt by combining the given statements.
//
// Pass any combination of dst.Stmt, []dst.Stmt and dst.Decl as arguments; each
// dst.Decl is wrapped in a *dst.DeclStmt. Nil arguments, and nil entries inside
// a []dst.Stmt, are skipped. Any other argument type results in a runtime panic.
//
// Every returned statement is a deep copy made with dst.Clone, so the result
// shares no nodes with the inputs.
func Statements(statements ...any) []dst.Stmt {
	// Validate the arguments and calculate an upper bound on the final size,
	// so bad input panics before any cloning work is done.
	size := 0
	for _, s := range statements {
		switch s := s.(type) {
		case nil:
			// Skip nils
		case dst.Stmt, dst.Decl:
			size++
		case []dst.Stmt:
			size += len(s)
		default:
			panic(fmt.Sprintf("expected dst.Stmt, []dst.Stmt, or dst.Decl but found %T", s))
		}
	}

	// Flatten and clone in a single pass; cloning avoids sharing nodes between trees.
	result := make([]dst.Stmt, 0, size)
	for _, s := range statements {
		switch s := s.(type) {
		case dst.Stmt:
			// Add a single statement
			result = append(result, dst.Clone(s).(dst.Stmt))
		case []dst.Stmt:
			// Add many statements
			for _, st := range s {
				// Cloning a nil entry would panic; skip it just like a nil argument.
				if st == nil {
					continue
				}
				result = append(result, dst.Clone(st).(dst.Stmt))
			}
		case dst.Decl:
			// Convert declaration to statement
			result = append(result, dst.Clone(&dst.DeclStmt{Decl: s}).(dst.Stmt))
		}
		// nil arguments need no work here; every other type was rejected above.
	}

	return result
}
// StatementBlock wraps the supplied statements in a *dst.BlockStmt.
//
// The statements are first combined and deep-cloned by Statements. If exactly
// one statement results and it is already a *dst.BlockStmt, that cloned block
// is returned as-is rather than being nested inside a second block.
func StatementBlock(statements ...dst.Stmt) *dst.BlockStmt {
	cloned := Statements(statements)
	if len(cloned) != 1 {
		return &dst.BlockStmt{List: cloned}
	}

	if existing, isBlock := cloned[0].(*dst.BlockStmt); isBlock {
		// Avoid double-wrapping a lone block
		return existing
	}

	return &dst.BlockStmt{List: cloned}
}