/
migrations.go
130 lines (102 loc) · 2.71 KB
/
migrations.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package db
import (
"fmt"
"time"
"github.com/go-sql-driver/mysql"
"github.com/gobuffalo/packr"
log "github.com/Sirupsen/logrus"
)
// dbAssets is the packr box rooted at ./migrations. TryRollback reads the
// rollback SQL for a version from it via dbAssets.Bytes(version.GetErrPath()).
var dbAssets = packr.NewBox("./migrations")
// CheckExists reports whether this version is already recorded in the
// migrations table.
//
// If the migrations table itself does not exist yet (MySQL error 1146), it is
// created by executing initialSQL and the lookup is retried exactly once.
// Any other database error, a failure to create the table, or a failure of the
// retried lookup is returned to the caller instead of panicking.
func (version *Version) CheckExists() (bool, error) {
	const query = "select count(1) as ex from migrations where version=?"

	exists, err := Mysql.SelectInt(query, version.VersionString())
	if err == nil {
		return exists > 0, nil
	}

	// 1146 is mysql table does not exist; anything else is a real failure.
	mysqlErr, ok := err.(*mysql.MySQLError)
	if !ok || mysqlErr.Number != 1146 {
		return false, err
	}

	fmt.Println("Creating migrations table")
	if _, err := Mysql.Exec(initialSQL); err != nil {
		return false, fmt.Errorf("creating migrations table: %w", err)
	}

	// Retry once rather than recursing, so a table that still cannot be found
	// (e.g. initialSQL targets a different schema) surfaces as an error instead
	// of looping forever.
	exists, err = Mysql.SelectInt(query, version.VersionString())
	if err != nil {
		return false, err
	}
	return exists > 0, nil
}
// Run applies this migration inside a single transaction.
//
// Every statement returned by GetSQL(GetPath()) is executed in order, with a
// "[i/n]" progress counter rewritten on one console line; empty statements are
// skipped. If a statement fails, the transaction is rolled back (a rollback
// failure is only logged), the failing statement is logged, and the statement's
// error is returned. Once all statements succeed, a row for this version is
// inserted into the migrations table and the transaction is committed; the
// Commit error, if any, is returned.
//
// NOTE(review): MySQL implicitly commits most DDL statements, so schema
// changes may not actually be undone by the rollback — confirm for the
// migrations in use.
func (version *Version) Run() error {
	fmt.Printf("Executing migration %s (at %v)...\n", version.HumanoidVersion(), time.Now())

	tx, beginErr := Mysql.Begin()
	if beginErr != nil {
		return beginErr
	}

	queries := version.GetSQL(version.GetPath())
	total := len(queries)
	for idx, stmt := range queries {
		fmt.Printf("\r [%d/%d]", idx+1, total)
		if stmt == "" {
			continue
		}
		if _, execErr := tx.Exec(stmt); execErr != nil {
			handleRollbackError(tx.Rollback())
			log.Warnf("\n ERR! Query: %v\n\n", stmt)
			return execErr
		}
	}

	const recordVersion = "insert into migrations set version=?, upgraded_date=?"
	if _, execErr := tx.Exec(recordVersion, version.VersionString(), time.Now()); execErr != nil {
		handleRollbackError(tx.Rollback())
		return execErr
	}

	// Terminate the progress-counter line before committing.
	fmt.Println()
	return tx.Commit()
}
// handleRollbackError logs a failed transaction rollback as a warning and
// ignores a nil error. The rollback error is intentionally not propagated:
// in Run the caller goes on to return the error that triggered the rollback.
func handleRollbackError(err error) {
	if err == nil {
		return
	}
	log.Warn(err.Error())
}
// TryRollback attempts to undo a failed migration by running the version's
// rollback SQL (the asset at GetErrPath) directly against the database.
//
// It is best-effort and returns nothing:
//   - if no rollback SQL exists, it prints a notice and returns;
//   - empty statements are skipped, matching Run;
//   - if a rollback statement fails, the error is logged and the remaining
//     statements are not executed.
func (version *Version) TryRollback() {
	fmt.Printf("Rolling back %s (time: %v)...\n", version.HumanoidVersion(), time.Now())
	data := dbAssets.Bytes(version.GetErrPath())
	if len(data) == 0 {
		fmt.Println("Rollback SQL does not exist.")
		fmt.Println()
		return
	}
	sql := version.GetSQL(version.GetErrPath())
	for _, query := range sql {
		// Run skips empty fragments too; sending one to MySQL is rejected as an
		// empty query and would abort the rollback for no real reason.
		if len(query) == 0 {
			continue
		}
		fmt.Printf(" [ROLLBACK] > %v\n", query)
		if _, err := Mysql.Exec(query); err != nil {
			// Surface the cause instead of silently dropping it.
			log.Warnf(" [ROLLBACK] query failed: %v", err)
			fmt.Println(" [ROLLBACK] - Stopping")
			return
		}
	}
}
// MigrateAll walks Versions from first to last and applies every version that
// is not yet recorded in the migrations table.
//
// Versions already recorded are skipped. An error while checking a version is
// returned immediately. If applying a version fails, TryRollback is invoked for
// that version and the Run error is returned; later versions are not attempted.
// "Migrations Finished" is printed only when at least one migration was applied
// and all of them succeeded.
func MigrateAll() error {
	fmt.Println("Checking DB migrations")

	applied := 0
	for _, ver := range Versions {
		alreadyApplied, checkErr := ver.CheckExists()
		if alreadyApplied {
			continue
		}
		if checkErr != nil {
			return checkErr
		}

		applied++
		if runErr := ver.Run(); runErr != nil {
			ver.TryRollback()
			return runErr
		}
	}

	if applied > 0 {
		fmt.Println("Migrations Finished")
	}
	return nil
}