-
Notifications
You must be signed in to change notification settings - Fork 230
/
myparse.go
118 lines (110 loc) · 3.34 KB
/
myparse.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package myparse
import (
"fmt"
"ariga.io/atlas/cmd/atlas/internal/sqlparse/parsefix"
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/schema"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/ast"
_ "github.com/pingcap/tidb/parser/test_driver"
)
// FixChange fixes the changes according to the given statement.
//
// The statement s is parsed with the TiDB MySQL parser; a parse error is
// returned as-is. If changes is empty it is returned unchanged.
//
// For a RENAME TABLE statement, every old→new pair is applied to changes via
// parsefix.RenameTable. For an ALTER TABLE statement, a RENAME TO clause is
// first folded into a ModifyTable+RenameTable pair (see renameTable); the
// leading change must then be a *schema.ModifyTable, to which every
// RENAME COLUMN and RENAME INDEX clause of the statement is applied.
// Other statement types leave changes untouched.
func FixChange(d migrate.Driver, s string, changes schema.Changes) (schema.Changes, error) {
	parsed, err := parser.New().ParseOneStmt(s, "", "")
	switch {
	case err != nil:
		return nil, err
	case len(changes) == 0:
		return changes, nil
	}
	if rename, ok := parsed.(*ast.RenameTableStmt); ok {
		for _, pair := range rename.TableToTables {
			changes = parsefix.RenameTable(changes, &parsefix.Rename{
				From: pair.OldTable.Name.O,
				To:   pair.NewTable.Name.O,
			})
		}
		return changes, nil
	}
	alter, ok := parsed.(*ast.AlterTableStmt)
	if !ok {
		return changes, nil
	}
	fixed, err := renameTable(d, alter, changes)
	if err != nil {
		return nil, err
	}
	modify, ok := fixed[0].(*schema.ModifyTable)
	if !ok {
		return nil, fmt.Errorf("expected modify-table change for alter-table statement, but got: %T", fixed[0])
	}
	for _, col := range renameColumns(alter) {
		parsefix.RenameColumn(modify, col)
	}
	for _, idx := range renameIndexes(alter) {
		parsefix.RenameIndex(modify, idx)
	}
	return fixed, nil
}
// renameColumns returns all renamed columns that exist in the statement.
//
// One parsefix.Rename (old column name → new column name) is produced for
// each RENAME COLUMN clause, in the order the clauses appear in stmt.Specs.
// The result is nil when the statement renames no columns.
func renameColumns(stmt *ast.AlterTableStmt) (rename []*parsefix.Rename) {
	for _, spec := range stmt.Specs {
		if spec.Tp != ast.AlterTableRenameColumn {
			continue
		}
		r := &parsefix.Rename{
			From: spec.OldColumnName.Name.O,
			To:   spec.NewColumnName.Name.O,
		}
		rename = append(rename, r)
	}
	return rename
}
// renameIndexes returns all renamed indexes that exist in the statement.
//
// One parsefix.Rename (FromKey → ToKey) is produced for each RENAME INDEX
// clause, in the order the clauses appear in stmt.Specs. The result is nil
// when the statement renames no indexes.
func renameIndexes(stmt *ast.AlterTableStmt) (rename []*parsefix.Rename) {
	for _, spec := range stmt.Specs {
		if spec.Tp != ast.AlterTableRenameIndex {
			continue
		}
		r := &parsefix.Rename{
			From: spec.FromKey.O,
			To:   spec.ToKey.O,
		}
		rename = append(rename, r)
	}
	return rename
}
// renameTable fixes the changes from ALTER command with RENAME into ModifyTable and RenameTable.
//
// When stmt contains a RENAME TO clause (only the first one is considered),
// changes is expected to hold exactly two entries: a DropTable for the table
// named in the statement and an AddTable for the new table name, in either
// order. They are replaced by:
//
//  1. a ModifyTable, still under the old name, carrying the diff between the
//     dropped and the added table definitions, and
//  2. a RenameTable from the old name to the new one.
//
// If stmt has no RENAME TO clause, changes is returned unchanged. An error is
// returned when the number of changes is not two, when either expected change
// is missing, or when drv.TableDiff fails.
func renameTable(drv migrate.Driver, stmt *ast.AlterTableStmt, changes schema.Changes) (schema.Changes, error) {
	var r *ast.AlterTableSpec
	for _, s := range stmt.Specs {
		if s.Tp == ast.AlterTableRenameTable {
			r = s
			break
		}
	}
	if r == nil {
		return changes, nil
	}
	if len(changes) != 2 {
		return nil, fmt.Errorf("unexpected number of changes for ALTER command with RENAME clause: %d", len(changes))
	}
	i, j := changes.IndexDropTable(stmt.Table.Name.O), changes.IndexAddTable(r.NewTable.Name.O)
	if i == -1 {
		return nil, fmt.Errorf("DropTable %q change was not found in changes", stmt.Table.Name)
	}
	if j == -1 {
		return nil, fmt.Errorf("AddTable %q change was not found in changes", r.NewTable.Name)
	}
	// Use the located positions instead of assuming DropTable is first: the
	// two changes may come in either order, and a hard-coded index would make
	// the type assertion panic.
	fromT, toT := changes[i].(*schema.DropTable).T, changes[j].(*schema.AddTable).T
	// Give the old definition the new name so the diff covers only the
	// structural changes. Presumably this is also needed because TableDiff
	// rejects or reports mismatched names — confirm against the driver.
	// NOTE(review): this renames the dropped table object in place rather
	// than a copy. Objects that point back to it (e.g. a self-referencing
	// foreign key's referenced table) therefore see the new name as well.
	fromT.Name = toT.Name
	diff, err := drv.TableDiff(fromT, toT)
	if err != nil {
		return nil, err
	}
	// The ModifyTable must target the table under its pre-rename name, since
	// it is applied before the RENAME.
	changeT := *toT
	changeT.Name = stmt.Table.Name.O
	return schema.Changes{
		// Modify the table first.
		&schema.ModifyTable{T: &changeT, Changes: diff},
		// Then, apply the RENAME.
		&schema.RenameTable{From: &changeT, To: toT},
	}, nil
}