/
table.go
194 lines (157 loc) · 5.33 KB
/
table.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
package schema
import (
"database/sql"
"fmt"
"strings"
"github.com/juju/errors"
)
// Connection wraps a database handle together with the storage engine that
// tables created through it will use.
type Connection struct {
// db is the handle every schema statement is executed against.
db *sql.DB
// tablesEngine is interpolated into the ENGINE= clause by CreateTable.
tablesEngine string
}
// NewConnection wraps db in a Connection used to apply schema changes.
// Tables created through the returned connection use the InnoDB engine.
func NewConnection(db *sql.DB) *Connection {
	conn := new(Connection)
	conn.db = db
	conn.tablesEngine = "InnoDB"
	return conn
}
// NewTestConnection wraps db in a Connection intended for tests. Tables
// created through the returned connection use the MEMORY engine instead of
// InnoDB.
func NewTestConnection(db *sql.DB) *Connection {
	conn := new(Connection)
	conn.db = db
	conn.tablesEngine = "MEMORY"
	return conn
}
// Column is the common interface shared by all column types. The fragment
// returned by SQL is embedded verbatim by CreateTable and AddColumn.
type Column interface {
// SQL generates the SQL needed to create the column.
SQL() string
}
// CreateTable issues a CREATE TABLE statement for name. The column list is
// built from the SQL fragment of each entry in columns, joined with commas;
// the table uses the engine configured on the connection, the utf8mb4
// charset and the utf8mb4_bin collation. A database error is returned traced.
func (conn *Connection) CreateTable(name string, columns []Column) error {
	defs := make([]string, 0, len(columns))
	for _, column := range columns {
		defs = append(defs, column.SQL())
	}
	body := strings.Join(defs, ",")

	query := fmt.Sprintf("CREATE TABLE `%s` (%s) ENGINE=%s DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin", name, body, conn.tablesEngine)
	_, err := conn.db.Exec(query)
	if err != nil {
		return errors.Trace(err)
	}
	return nil
}
// DropTable issues a DROP TABLE statement for the table called name.
// A database error is returned traced.
func (conn *Connection) DropTable(name string) error {
	query := fmt.Sprintf("DROP TABLE `%s`", name)
	_, err := conn.db.Exec(query)
	if err != nil {
		return errors.Trace(err)
	}
	return nil
}
// showLikeEscaper neutralises the characters that are special inside the
// quoted pattern of a SHOW TABLES LIKE statement so that a table name is
// matched literally: "_" and "%" are LIKE wildcards, backslash is the escape
// character and "'" would terminate the string literal.
//
// NOTE(review): this relies on backslash being the active escape character,
// i.e. the server is not running with NO_BACKSLASH_ESCAPES — confirm.
var showLikeEscaper = strings.NewReplacer(
	`\`, `\\\\`, // unescaped once by the string parser, once by LIKE
	`_`, `\_`,
	`%`, `\%`,
	`'`, `''`,
)

// TableExists reports whether a table called exactly name is listed by
// SHOW TABLES on the connection.
//
// The name is escaped before being placed in the LIKE pattern; without that,
// an underscore or percent sign in the name would act as a wildcard and a
// differently named table could be reported as existing.
//
// It returns (false, nil) when no row matches, and (false, err) with a traced
// error when the query itself fails.
func (conn *Connection) TableExists(name string) (bool, error) {
	query := fmt.Sprintf("SHOW TABLES LIKE '%s';", showLikeEscaper.Replace(name))

	var found string
	err := conn.db.QueryRow(query).Scan(&found)
	switch {
	case err == sql.ErrNoRows:
		// No row means no table matched the pattern.
		return false, nil
	case err != nil:
		return false, errors.Trace(err)
	default:
		return true, nil
	}
}
// CreateTableIfNotExists creates the table only when TableExists reports it
// is missing. The existence check and the creation are two separate
// statements. Errors from either step are returned traced.
func (conn *Connection) CreateTableIfNotExists(name string, columns []Column) error {
	exists, err := conn.TableExists(name)
	if err != nil {
		return errors.Trace(err)
	}
	if exists {
		return nil
	}
	return errors.Trace(conn.CreateTable(name, columns))
}
// RenameColumn renames oldColumnName to columnName in tableName using
// ALTER TABLE ... CHANGE. columnType is written verbatim after the new name,
// so it should repeat the column's current type; passing a different type
// would also alter the column, which is possible but not recommended.
func (conn *Connection) RenameColumn(tableName, oldColumnName, columnName, columnType string) error {
	query := fmt.Sprintf("ALTER TABLE `%s` CHANGE `%s` `%s` %s", tableName, oldColumnName, columnName, columnType)
	_, err := conn.db.Exec(query)
	if err != nil {
		return errors.Trace(err)
	}
	return nil
}
// AddColumn adds col to the existing table tableName. The column definition
// is whatever col.SQL() returns, appended verbatim to ALTER TABLE ... ADD.
func (conn *Connection) AddColumn(tableName string, col Column) error {
	definition := col.SQL()
	query := fmt.Sprintf("ALTER TABLE `%s` ADD %s", tableName, definition)
	_, err := conn.db.Exec(query)
	if err != nil {
		return errors.Trace(err)
	}
	return nil
}
// AlterColumn changes the definition of columnName in tableName using
// ALTER TABLE ... MODIFY. columnType is written verbatim after the column
// name and therefore has to carry the complete type of the column.
func (conn *Connection) AlterColumn(tableName, columnName, columnType string) error {
	query := fmt.Sprintf("ALTER TABLE `%s` MODIFY `%s` %s", tableName, columnName, columnType)
	_, err := conn.db.Exec(query)
	if err != nil {
		return errors.Trace(err)
	}
	return nil
}
// DropColumn removes columnName from tableName.
//
// The column name is backtick-quoted like every other identifier built in
// this file, so names that are reserved words or contain special characters
// are accepted. Callers must pass the bare name, not an already quoted one.
// A database error is returned traced.
func (conn *Connection) DropColumn(tableName, columnName string) error {
	stmt := fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN `%s`", tableName, columnName)
	if _, err := conn.db.Exec(stmt); err != nil {
		return errors.Trace(err)
	}
	return nil
}
// DropPrimaryKey drops the primary key index of the table called name.
// Only the index is removed; the underlying columns are left in place.
func (conn *Connection) DropPrimaryKey(name string) error {
	query := fmt.Sprintf("ALTER TABLE `%s` DROP PRIMARY KEY", name)
	_, err := conn.db.Exec(query)
	if err != nil {
		return errors.Trace(err)
	}
	return nil
}
// AssignPrimaryKey adds a primary key over columnNames to tableName. The
// table must not have a primary key at that point: either it never had one
// or it was dropped beforehand.
func (conn *Connection) AssignPrimaryKey(tableName string, columnNames []string) error {
	keyCols := quoteCols(columnNames)
	query := fmt.Sprintf("ALTER TABLE `%s` ADD PRIMARY KEY (%s)", tableName, keyCols)
	_, err := conn.db.Exec(query)
	if err != nil {
		return errors.Trace(err)
	}
	return nil
}
// AddUnique adds a unique index called indexName over columnNames to
// tableName.
//
// The index name is backtick-quoted, matching how table and column names are
// quoted in this file, so reserved words and special characters are accepted.
// Callers must pass the bare index name. A database error is returned traced.
func (conn *Connection) AddUnique(tableName, indexName string, columnNames []string) error {
	stmt := fmt.Sprintf("ALTER TABLE `%s` ADD UNIQUE INDEX `%s` (%s)", tableName, indexName, quoteCols(columnNames))
	if _, err := conn.db.Exec(stmt); err != nil {
		return errors.Trace(err)
	}
	return nil
}
// DropUnique drops the index called indexName from tableName.
//
// The index name is backtick-quoted, matching how table and column names are
// quoted in this file, so reserved words and special characters are accepted.
// Callers must pass the bare index name. A database error is returned traced.
func (conn *Connection) DropUnique(tableName, indexName string) error {
	stmt := fmt.Sprintf("ALTER TABLE `%s` DROP INDEX `%s`", tableName, indexName)
	if _, err := conn.db.Exec(stmt); err != nil {
		return errors.Trace(err)
	}
	return nil
}
// RenameTable renames the table oldName to name with a RENAME TABLE
// statement. A database error is returned traced.
func (conn *Connection) RenameTable(oldName, name string) error {
	query := fmt.Sprintf("RENAME TABLE `%s` TO `%s`", oldName, name)
	_, err := conn.db.Exec(query)
	if err != nil {
		return errors.Trace(err)
	}
	return nil
}
// quoteCols wraps every column name in backticks and joins the results with
// commas, producing a list usable inside an index or key definition. An empty
// or nil slice yields the empty string.
func quoteCols(cols []string) string {
	quoted := make([]string, len(cols))
	for i := range cols {
		quoted[i] = fmt.Sprintf("`%s`", cols[i])
	}
	return strings.Join(quoted, ",")
}