forked from go-gorm/gen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gentool.go
243 lines (220 loc) · 7.38 KB
/
gentool.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
package main
import (
"flag"
"fmt"
"log"
"os"
"strings"
"github.com/Edward-Alphonse/gen"
"gopkg.in/yaml.v3"
"gorm.io/driver/clickhouse"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
)
// DBType names a database dialect; connectDB switches on it to pick the GORM
// driver used to open the connection.
type DBType string

// Supported DBType values, one per GORM driver this tool imports:
// mysql || postgres || sqlite || sqlserver || clickhouse.
const (
	dbMySQL      DBType = "mysql"
	dbPostgres   DBType = "postgres"
	dbSQLite     DBType = "sqlite"
	dbSQLServer  DBType = "sqlserver"
	dbClickHouse DBType = "clickhouse"
)
// defaultQueryPath is the output directory for generated query code when no
// non-empty path is configured: it is the -outPath flag default and the value
// revise fills in for an empty OutPath.
const defaultQueryPath = "./dao/query"
// CmdParams holds the generator settings, filled either from command-line
// flags (argParse) or from the database section of a YAML file (yaml tags).
type CmdParams struct {
	DSN               string   `yaml:"dsn"`               // data source name; format depends on the driver, consult[https://gorm.io/docs/connecting_to_the_database.html]
	DB                string   `yaml:"db"`                // input mysql or postgres or sqlite or sqlserver or clickhouse. consult[https://gorm.io/docs/connecting_to_the_database.html]
	Tables            []string `yaml:"tables"`            // tables to generate; leave empty to generate every table in the database
	OnlyModel         bool     `yaml:"onlyModel"`         // only generate models, skip query code
	OutPath           string   `yaml:"outPath"`           // directory for generated query code; defaults to ./dao/query when empty
	OutFile           string   `yaml:"outFile"`           // query code file name, default: gen.go
	WithUnitTest      bool     `yaml:"withUnitTest"`      // generate unit test for query code
	ModelPkgName      string   `yaml:"modelPkgName"`      // generated model code's package name (passed to gen.Config.ModelPkgPath)
	FieldNullable     bool     `yaml:"fieldNullable"`     // generate with pointer when field is nullable
	FieldCoverable    bool     `yaml:"fieldCoverable"`    // generate with pointer when field has default value
	FieldWithIndexTag bool     `yaml:"fieldWithIndexTag"` // generate field with gorm index tag
	FieldWithTypeTag  bool     `yaml:"fieldWithTypeTag"`  // generate field with gorm column type tag
	FieldSignable     bool     `yaml:"fieldSignable"`     // detect integer field's unsigned type, adjust generated data type
}
// revise normalizes c in place and returns it so calls can be chained.
//
// An empty DB becomes "mysql" and an empty OutPath becomes defaultQueryPath.
// Each table name is trimmed of surrounding whitespace, and names that are
// blank after trimming are dropped. A nil receiver is returned unchanged.
func (c *CmdParams) revise() *CmdParams {
	if c == nil {
		return c
	}

	if c.DB == "" {
		c.DB = string(dbMySQL)
	}
	if c.OutPath == "" {
		c.OutPath = defaultQueryPath
	}

	// Leave an empty or nil table list exactly as it is; only rebuild a non-empty one.
	if len(c.Tables) > 0 {
		kept := make([]string, 0, len(c.Tables))
		for _, raw := range c.Tables {
			if name := strings.TrimSpace(raw); name != "" {
				kept = append(kept, name)
			}
		}
		c.Tables = kept
	}
	return c
}
// YamlConfig is the top-level layout of the YAML file passed with -c.
// parseCmdFromYaml decodes the whole file into it and hands back only the
// Database section.
type YamlConfig struct {
	Version  string     `yaml:"version"`  // config file version; decoded but not read anywhere in this file
	Database *CmdParams `yaml:"database"` // generator settings; nil when the file has no database section
}
// connectDB opens a GORM connection to dsn using the driver selected by t.
//
// It returns an error, without attempting any connection, when dsn is empty
// or when t is not one of mysql, postgres, sqlite, sqlserver or clickhouse.
// Otherwise it returns whatever gorm.Open returns for that driver.
func connectDB(t DBType, dsn string) (*gorm.DB, error) {
	if dsn == "" {
		return nil, fmt.Errorf("dsn cannot be empty")
	}

	switch t {
	case dbMySQL:
		return gorm.Open(mysql.Open(dsn))
	case dbPostgres:
		return gorm.Open(postgres.Open(dsn))
	case dbSQLite:
		return gorm.Open(sqlite.Open(dsn))
	case dbSQLServer:
		return gorm.Open(sqlserver.Open(dsn))
	case dbClickHouse:
		return gorm.Open(clickhouse.Open(dsn))
	default:
		// Keep this list in sync with the cases above so users see every accepted value.
		return nil, fmt.Errorf("unknown db %q (support mysql || postgres || sqlite || sqlserver || clickhouse for now)", t)
	}
}
// genModels calls g.GenerateModel once per table and returns the results in
// table order.
//
// When tables is empty, the table list is taken from db.Migrator().GetTables().
// An error is returned only if that lookup fails.
func genModels(g *gen.Generator, db *gorm.DB, tables []string) (models []interface{}, err error) {
	names := tables
	if len(names) == 0 {
		// No explicit selection: fall back to every table the database reports.
		if names, err = db.Migrator().GetTables(); err != nil {
			return nil, fmt.Errorf("GORM migrator get all tables fail: %w", err)
		}
	}

	models = make([]interface{}, 0, len(names))
	for _, name := range names {
		models = append(models, g.GenerateModel(name))
	}
	return models, nil
}
// parseCmdFromYaml reads the YAML config file at path and returns its
// database section.
//
// Every failure is fatal and is reported via log.Fatalf, which exits the
// process. Failures are: the file cannot be opened, its contents cannot be
// decoded, or it has no database section. This function therefore never
// returns nil.
func parseCmdFromYaml(path string) *CmdParams {
	file, err := os.Open(path)
	if err != nil {
		// The *PathError from os.Open already names the file.
		log.Fatalf("parseCmdFromYaml fail %s", err.Error())
	}
	defer file.Close() // nolint

	var yamlConfig YamlConfig
	if err := yaml.NewDecoder(file).Decode(&yamlConfig); err != nil {
		log.Fatalf("parseCmdFromYaml fail %s: %v", path, err)
	}
	// Without this check a file lacking "database:" yields nil, and the caller
	// can only report a generic "parse config fail".
	if yamlConfig.Database == nil {
		log.Fatalf("parseCmdFromYaml fail %s: missing database section", path)
	}
	return yamlConfig.Database
}
// argParse registers the gentool flags on the default flag set, parses the
// command line, and returns the resulting parameters.
//
// If -c names a YAML file, that file is the only source of settings: every
// other flag is ignored and the result of parseCmdFromYaml is returned.
// Otherwise the parameters are built directly from the flag values.
// Because the flags are registered on flag.CommandLine, call this at most once.
func argParse() *CmdParams {
	configPath := flag.String("c", "", "is path for gen.yml")
	dsn := flag.String("dsn", "", "consult[https://gorm.io/docs/connecting_to_the_database.html]")
	db := flag.String("db", string(dbMySQL), "input mysql|postgres|sqlite|sqlserver|clickhouse. consult[https://gorm.io/docs/connecting_to_the_database.html]")
	tables := flag.String("tables", "", "enter the required data table or leave it blank")
	onlyModel := flag.Bool("onlyModel", false, "only generate models (without query file)")
	outPath := flag.String("outPath", defaultQueryPath, "specify a directory for output")
	outFile := flag.String("outFile", "", "query code file name, default: gen.go")
	withUnitTest := flag.Bool("withUnitTest", false, "generate unit test for query code")
	modelPkgName := flag.String("modelPkgName", "", "generated model code's package name")
	fieldNullable := flag.Bool("fieldNullable", false, "generate with pointer when field is nullable")
	fieldCoverable := flag.Bool("fieldCoverable", false, "generate with pointer when field has default value")
	fieldWithIndexTag := flag.Bool("fieldWithIndexTag", false, "generate field with gorm index tag")
	fieldWithTypeTag := flag.Bool("fieldWithTypeTag", false, "generate field with gorm column type tag")
	fieldSignable := flag.Bool("fieldSignable", false, "detect integer field's unsigned type, adjust generated data type")
	flag.Parse()

	if *configPath != "" {
		return parseCmdFromYaml(*configPath)
	}

	params := &CmdParams{
		DSN:               *dsn,
		DB:                *db,
		OnlyModel:         *onlyModel,
		OutPath:           *outPath,
		OutFile:           *outFile,
		WithUnitTest:      *withUnitTest,
		ModelPkgName:      *modelPkgName,
		FieldNullable:     *fieldNullable,
		FieldCoverable:    *fieldCoverable,
		FieldWithIndexTag: *fieldWithIndexTag,
		FieldWithTypeTag:  *fieldWithTypeTag,
		FieldSignable:     *fieldSignable,
	}
	// Split only a non-empty list: strings.Split("", ",") returns [""], not nil.
	if *tables != "" {
		params.Tables = strings.Split(*tables, ",")
	}
	return params
}
// main is the gentool entry point. It:
//  1. loads and normalizes the configuration;
//  2. connects to the database;
//  3. builds a gen.Generator from the configuration;
//  4. generates a model for each selected table;
//  5. unless OnlyModel is set, passes those models to g.ApplyBasic;
//  6. runs g.Execute.
//
// Any setup failure is fatal.
func main() {
	cfg := argParse().revise()
	if cfg == nil {
		log.Fatalln("parse config fail")
	}

	db, err := connectDB(DBType(cfg.DB), cfg.DSN)
	if err != nil {
		log.Fatalln("connect db server fail:", err)
	}

	genCfg := gen.Config{
		OutPath: cfg.OutPath,
		OutFile: cfg.OutFile,
		// The CLI exposes a package *name*; it is forwarded as gen's ModelPkgPath.
		ModelPkgPath:      cfg.ModelPkgName,
		WithUnitTest:      cfg.WithUnitTest,
		FieldNullable:     cfg.FieldNullable,
		FieldCoverable:    cfg.FieldCoverable,
		FieldWithIndexTag: cfg.FieldWithIndexTag,
		FieldWithTypeTag:  cfg.FieldWithTypeTag,
		FieldSignable:     cfg.FieldSignable,
	}
	g := gen.NewGenerator(genCfg)
	g.UseDB(db)

	models, err := genModels(g, db, cfg.Tables)
	if err != nil {
		log.Fatalln("get tables info fail:", err)
	}
	if !cfg.OnlyModel {
		g.ApplyBasic(models...)
	}
	g.Execute()
}