/
schema.go
299 lines (272 loc) · 7.94 KB
/
schema.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
/*
* Copyright(c) 2019 Lianjia, Inc. All Rights Reserved
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rebuild
import (
"database/sql"
"fmt"
"io/ioutil"
"os"
"regexp"
"strings"
"github.com/juju/errors"
"github.com/LianjiaTech/lightning/common"
// database/sql
_ "github.com/go-sql-driver/mysql"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
)
// Package-level schema state populated by the loaders in this file.
var (
	// Schemas maps a backtick-quoted `db`.`table` key to the parsed
	// CREATE TABLE statement for that table.
	Schemas map[string]*ast.CreateTableStmt
	// Columns maps `db`.`table` to the table's backtick-quoted column names,
	// in declaration order.
	Columns map[string][]string
	// PrimaryKeys maps `db`.`table` to the table's backtick-quoted primary-key
	// column names; tables without a primary key use all of their columns.
	PrimaryKeys map[string][]string
)
// LoadSchemaInfo load schema info from file or mysql
func LoadSchemaInfo() {
if common.Config.MySQL.SchemaFile != "" {
// load from file
err := loadSchemaFromFile()
if err != nil {
common.Log.Error(errors.Trace(err).Error())
}
return
} else {
// load from mysql server
err := loadSchemaFromMySQL()
if err != nil {
common.Log.Error(errors.Trace(err).Error())
}
}
}
// loadSchemaFromFile parses the SQL file named by
// common.Config.MySQL.SchemaFile into Schemas via schemaAppend, then rebuilds
// Columns and PrimaryKeys.
//
// A missing or unreadable file is reported before anything is touched. A
// parse error from schemaAppend is returned only after the indexes have been
// rebuilt from whatever Schemas already holds.
func loadSchemaFromFile() error {
	path := common.Config.MySQL.SchemaFile
	common.Log.Debug("loadSchemaFromFile %s", path)

	if _, statErr := os.Stat(path); statErr != nil {
		return statErr
	}
	content, readErr := ioutil.ReadFile(path)
	if readErr != nil {
		return readErr
	}

	// An empty default database makes schemaAppend fall back to its placeholder.
	appendErr := schemaAppend("", string(content))
	buildColumns()
	buildPrimaryKeys()
	return appendErr
}
// loadSchemaFromMySQL fetches table definitions from the master MySQL server
// described by common.MasterInfo, registers them in Schemas via schemaAppend,
// and then rebuilds Columns and PrimaryKeys.
//
// The system schemas information_schema, sys, mysql and performance_schema are
// skipped. common.Config.Filters.Tables limits which tables are loaded (see
// schemaTableWanted). Failures for an individual database or table are logged
// and skipped. Only a failure to open the handle or to run SHOW DATABASES is
// returned.
func loadSchemaFromMySQL() error {
	const dsnFormat = "%s:%s@tcp(%s:%d)/?charset=%s&timeout=5s"
	dsn := fmt.Sprintf(dsnFormat,
		common.MasterInfo.MasterUser,
		common.MasterInfo.MasterPassword,
		common.MasterInfo.MasterHost,
		common.MasterInfo.MasterPort,
		common.Config.Global.Charset,
	)
	// Log the DSN with the password masked so credentials never reach the log.
	common.Log.Debug("loadSchemaFromMySQL %s", fmt.Sprintf(dsnFormat,
		common.MasterInfo.MasterUser,
		"***",
		common.MasterInfo.MasterHost,
		common.MasterInfo.MasterPort,
		common.Config.Global.Charset,
	))
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return err
	}
	defer db.Close()

	databases, err := schemaUserDatabases(db)
	if err != nil {
		return err
	}
	for _, database := range databases {
		tables, err := schemaTablesOf(db, database)
		if err != nil {
			common.Log.Error(errors.Trace(err).Error())
			continue
		}
		for _, table := range tables {
			if schemaTableWanted(database, table) {
				schemaLoadCreateTable(db, database, table)
			}
		}
	}
	buildColumns()
	buildPrimaryKeys()
	return nil
}

// schemaUserDatabases returns the databases listed by SHOW DATABASES, minus
// MySQL's system schemas. Rows that fail to scan, and an error that ends the
// iteration early, are logged. In those cases the names read so far are
// still returned.
func schemaUserDatabases(db *sql.DB) ([]string, error) {
	rows, err := db.Query("SHOW DATABASES;")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var databases []string
	for rows.Next() {
		var database string
		if err := rows.Scan(&database); err != nil {
			common.Log.Error(errors.Trace(err).Error())
			continue
		}
		switch database {
		case "information_schema", "sys", "mysql", "performance_schema":
			// system schema: skipped
		default:
			databases = append(databases, database)
		}
	}
	if err := rows.Err(); err != nil {
		common.Log.Error(errors.Trace(err).Error())
	}
	return databases, nil
}

// schemaTablesOf returns the names listed by SHOW TABLES FROM database. Only
// a failed query is returned as an error. Scan failures and iteration errors
// are logged, and the names read so far are returned.
func schemaTablesOf(db *sql.DB, database string) ([]string, error) {
	rows, err := db.Query(fmt.Sprintf("SHOW TABLES FROM `%s`", database))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var table string
		if err := rows.Scan(&table); err != nil {
			common.Log.Error(errors.Trace(err).Error())
			continue
		}
		tables = append(tables, table)
	}
	if err := rows.Err(); err != nil {
		common.Log.Error(errors.Trace(err).Error())
	}
	return tables, nil
}

// schemaTableWanted reports whether database.table should be loaded under
// common.Config.Filters.Tables. With no filters every table is wanted.
//
// Backticks in filter entries are ignored. An entry containing "%" admits
// every table of the database it is prefixed with. Wildcards are not matched
// precisely yet, so this may load more tables than needed. Any other entry
// must equal "database.table" exactly. On servers with many tables, loading
// only the filtered ones greatly speeds up schema loading.
func schemaTableWanted(database, table string) bool {
	filters := common.Config.Filters.Tables
	if len(filters) == 0 {
		return true
	}
	prefix := database + "."
	name := prefix + table
	for _, tb := range filters {
		plain := strings.Replace(tb, "`", "", -1)
		// TODO: proper "%" matching; for now load every table of that database.
		if strings.Contains(tb, "%") && strings.HasPrefix(plain, prefix) {
			return true
		}
		if plain == name {
			return true
		}
	}
	return false
}

// schemaLoadCreateTable runs SHOW CREATE TABLE for database.table and
// registers the DDL with schemaAppend.
//
// Views are skipped: their SHOW CREATE output has 4 columns instead of 2. If
// the DDL cannot be parsed, the simplified definition from buildFakeTable is
// registered instead. All failures are logged.
func schemaLoadCreateTable(db *sql.DB, database, table string) {
	rows, err := db.Query(fmt.Sprintf("SHOW CREATE TABLE `%s`.`%s`;", database, table))
	if err != nil {
		common.Log.Error(errors.Trace(err).Error())
		return
	}
	// Every early return below must release the result set and the
	// connection it holds.
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		common.Log.Error(errors.Trace(err).Error())
		return
	}
	// SHOW CREATE VIEW WILL GET 4 COLUMNS
	if len(cols) != 2 {
		common.Log.Info("by pass host: %s, port: %d, database: %s, table: %s",
			common.MasterInfo.MasterHost,
			common.MasterInfo.MasterPort,
			database, table)
		return
	}
	for rows.Next() {
		var name, schema string
		if err := rows.Scan(&name, &schema); err != nil {
			common.Log.Error("host: %s, port: %d, database: %s, table: %s, error: %s",
				common.MasterInfo.MasterHost,
				common.MasterInfo.MasterPort,
				database, table, errors.Trace(err).Error())
			continue
		}
		if err := schemaAppend(database, schema); err != nil {
			common.Log.Error("host: %s, port: %d, database: %s, table: %s, sql: %s, error: %s",
				common.MasterInfo.MasterHost,
				common.MasterInfo.MasterPort,
				database, table,
				schema,
				errors.Trace(err).Error())
			fake := buildFakeTable(db, fmt.Sprintf("`%s`.`%s`", database, table))
			if err := schemaAppend(database, fake); err != nil {
				common.Log.Error("host: %s, port: %d, database: %s, table: %s, fake sql: %s, error: %s",
					common.MasterInfo.MasterHost,
					common.MasterInfo.MasterPort,
					database, table,
					fake,
					errors.Trace(err).Error())
			}
		}
	}
	if err := rows.Err(); err != nil {
		common.Log.Error(errors.Trace(err).Error())
	}
}
// schemaAppend parses sql (one or more statements) and registers every
// CREATE TABLE it contains in Schemas, keyed as `schema`.`table`.
//
// database is the default schema for CREATE TABLE statements without a
// qualifier. An empty value is replaced by the placeholder "%". A USE
// statement switches that default for the statements after it.
//
// A CREATE TABLE with its own schema qualifier is keyed by that schema. This
// keeps Schemas keys consistent with the keys buildColumns and
// buildPrimaryKeys derive from node.Table.Schema. Other statement types are
// ignored. If parsing fails, the error is returned and nothing is registered.
func schemaAppend(database, sql string) error {
	sql = removeIncompatibleWords(sql)
	stmts, err := TiParse(sql, common.Config.Global.Charset, mysql.Charsets[common.Config.Global.Charset])
	if err != nil {
		return err
	}
	if database == "" {
		database = "%"
	}
	// Writing to a nil map panics; create Schemas if nobody initialized it yet.
	if Schemas == nil {
		Schemas = make(map[string]*ast.CreateTableStmt)
	}
	for _, stmt := range stmts {
		switch node := stmt.(type) {
		case *ast.CreateTableStmt:
			if node.Table.Schema.String() == "" {
				node.Table.Schema = model.NewCIStr(database)
			}
			key := fmt.Sprintf("`%s`.`%s`", node.Table.Schema.String(), node.Table.Name.String())
			Schemas[key] = node
		case *ast.UseStmt:
			database = node.DBName
		}
	}
	return nil
}
// Patterns used by removeIncompatibleWords. They are compiled once at package
// initialization rather than on every call.
var (
	// CONSTRAINT col_fk FOREIGN KEY (col) REFERENCES tb (id) ON UPDATE CASCADE
	incompatibleOnUpdateCascade = regexp.MustCompile(` ON UPDATE CASCADE`)
	// FULLTEXT KEY col_fk (col) /*!50100 WITH PARSER `ngram` */
	// /*!50100 PARTITION BY LIST (col)
	incompatibleVersionComment = regexp.MustCompile(`/\*!5`)
	// col varchar(10) CHARACTER SET gbk DEFAULT NULL
	incompatibleColumnCharset = regexp.MustCompile(`CHARACTER SET [a-z_0-9]* `)
)

// removeIncompatibleWords removes constructs that pingcap/parser cannot parse
// from a schema statement. It applies three rewrites:
//   - drops " ON UPDATE CASCADE";
//   - rewrites MySQL 5.x version comments "/*!5..." to "/* 5...", so their
//     content is read as an ordinary comment;
//   - drops "CHARACTER SET <lowercase name> " clauses.
//
// Note: it only targets the output of MySQL's SHOW CREATE TABLE. Hand-written
// SQL may contain other incompatible syntax that is left untouched.
func removeIncompatibleWords(sql string) string {
	sql = incompatibleOnUpdateCascade.ReplaceAllString(sql, "")
	sql = incompatibleVersionComment.ReplaceAllString(sql, "/* 5")
	sql = incompatibleColumnCharset.ReplaceAllString(sql, "")
	return sql
}
// buildColumns rebuilds Columns from scratch using every statement in
// Schemas. Each table is keyed `schema`.`name`, taken from the statement's
// own Table field. Its column names are appended backtick-quoted, in the
// order they appear in the CREATE TABLE.
func buildColumns() {
	Columns = make(map[string][]string)
	for _, stmt := range Schemas {
		key := fmt.Sprintf("`%s`.`%s`", stmt.Table.Schema.String(), stmt.Table.Name.String())
		for i := range stmt.Cols {
			quoted := fmt.Sprintf("`%s`", stmt.Cols[i].Name.String())
			Columns[key] = append(Columns[key], quoted)
		}
	}
}
// buildPrimaryKeys rebuilds PrimaryKeys from scratch using every statement in
// Schemas, keyed `schema`.`name` from the statement's Table field. Column
// names are backtick-quoted.
//
// The primary key is read from a table-level PRIMARY KEY constraint. If there
// is none, columns declared with an inline PRIMARY KEY option are used (for
// example `id INT PRIMARY KEY` in a hand-written schema file). A table with
// no primary key at all falls back to all of its columns, as found in
// Columns, so buildColumns must run first.
func buildPrimaryKeys() {
	PrimaryKeys = make(map[string][]string)
	for _, schema := range Schemas {
		table := fmt.Sprintf("`%s`.`%s`", schema.Table.Schema.String(), schema.Table.Name.String())
		for _, con := range schema.Constraints {
			if con.Tp == ast.ConstraintPrimaryKey {
				for _, col := range con.Keys {
					PrimaryKeys[table] = append(PrimaryKeys[table], fmt.Sprintf("`%s`", col.Column.String()))
				}
			}
		}
		// Inline column-level PRIMARY KEY, which the parser does not turn into a constraint.
		if len(PrimaryKeys[table]) == 0 {
			for _, col := range schema.Cols {
				for _, opt := range col.Options {
					if opt.Tp == ast.ColumnOptionPrimaryKey {
						PrimaryKeys[table] = append(PrimaryKeys[table], fmt.Sprintf("`%s`", col.Name.String()))
						break
					}
				}
			}
		}
		// If the table has no primary key, use all of its columns together as
		// the key. Copy the slice so that appending to one map later cannot
		// overwrite the other's backing array.
		if len(PrimaryKeys[table]) == 0 {
			PrimaryKeys[table] = append([]string(nil), Columns[table]...)
		}
	}
}
// buildFakeTable builds a minimal CREATE TABLE statement for table from the
// output of SHOW COLUMNS. Every column is declared INT. Only the column names
// and primary-key membership are carried over. In this file it serves as the
// fallback when the real SHOW CREATE TABLE output fails to parse.
//
// table is interpolated into the SQL verbatim, so callers must pass it
// already quoted (the caller in this file passes `db`.`tbl`).
//
// The PRIMARY KEY clause is emitted only when at least one column has
// Key == "PRI". An empty key list would be invalid SQL. The function returns
// "" when the columns cannot be read or none are found. Errors are logged.
func buildFakeTable(db *sql.DB, table string) string {
	res, err := db.Query(fmt.Sprintf("SHOW COLUMNS FROM %s", table))
	if err != nil {
		common.Log.Error(err.Error())
		return ""
	}
	defer res.Close()

	var columns, primary []string
	for res.Next() {
		// SHOW COLUMNS yields Field, Type, Null, Key, Default, Extra; only Field and Key are used.
		var col, key string
		var ignored []byte
		if err := res.Scan(&col, &ignored, &ignored, &key, &ignored, &ignored); err != nil {
			common.Log.Error(err.Error())
			continue
		}
		columns = append(columns, fmt.Sprintf("`%s` INT", col))
		if key == "PRI" {
			primary = append(primary, fmt.Sprintf("`%s`", col))
		}
	}
	if err := res.Err(); err != nil {
		common.Log.Error(err.Error())
		return ""
	}
	if len(columns) == 0 {
		return ""
	}
	if len(primary) > 0 {
		columns = append(columns, fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(primary, ",")))
	}
	return fmt.Sprintf("CREATE TABLE %s (%s);", table, strings.Join(columns, ","))
}
// onlyTable drops the database qualifier from a backtick-quoted `db`.`table`
// name and returns just the quoted table part, e.g. "`db`.`t`" -> "`t`".
// A name without a qualifier is returned wrapped in backticks.
func onlyTable(table string) string {
	parts := strings.Split(strings.Trim(table, "`"), "`.`")
	// strings.Split with a non-empty separator always yields at least one element.
	return fmt.Sprintf("`%s`", parts[len(parts)-1])
}