forked from sqlc-dev/sqlc
/
mysql.go
84 lines (68 loc) · 1.6 KB
/
mysql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package sqltest
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/StanVerse/sqld/internal/sql/sqlpath"
_ "github.com/go-sql-driver/mysql"
)
// MySQL creates an isolated MySQL database for a single test and applies
// the given migrations to it.
//
// Connection settings come from environment variables, falling back to
// defaults when a variable is unset or empty:
//
//	MYSQL_USER           (default "root")
//	MYSQL_ROOT_PASSWORD  (default "mysecretpassword")
//	MYSQL_HOST           (default "127.0.0.1")
//	MYSQL_PORT           (default "3306")
//	MYSQL_DATABASE       (default "dinotest")
//
// MYSQL_DATABASE names the database used for the administrative connection,
// through which a new database "sqltest_mysql_<id>" is created. migrations is
// expanded with sqlpath.Glob and each resulting file is executed, in the
// order returned, against the new database. Any setup failure aborts the test
// via t.Fatal; in that case the partially set-up database is dropped and both
// connection pools are closed before the test goroutine exits.
//
// The returned cleanup function closes the returned *sql.DB, drops the
// per-test database and closes the administrative pool, so the returned
// handle must not be used after cleanup runs.
func MySQL(t *testing.T, migrations []string) (*sql.DB, func()) {
	t.Helper()

	getenv := func(key, fallback string) string {
		if v := os.Getenv(key); v != "" {
			return v
		}
		return fallback
	}
	user := getenv("MYSQL_USER", "root")
	pass := getenv("MYSQL_ROOT_PASSWORD", "mysecretpassword")
	host := getenv("MYSQL_HOST", "127.0.0.1")
	port := getenv("MYSQL_PORT", "3306")
	data := getenv("MYSQL_DATABASE", "dinotest")

	// dsn builds the driver connection string for dbname. multiStatements
	// lets a whole migration file run in a single Exec.
	dsn := func(dbname string) string {
		return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true&parseTime=true", user, pass, host, port, dbname)
	}

	// Log the target without the password so credentials stay out of test
	// output and CI logs.
	t.Logf("db: %s:***@tcp(%s:%s)/%s", user, host, port, data)

	db, err := sql.Open("mysql", dsn(data))
	if err != nil {
		t.Fatal(err)
	}

	// For each test, pick a new database name at random.
	dbName := "sqltest_mysql_" + id()
	if _, err := db.Exec("CREATE DATABASE " + dbName); err != nil {
		db.Close()
		t.Fatal(err)
	}

	var sdb *sql.DB
	// teardown releases everything created above. The per-test pool is
	// closed first so none of its connections can hold locks that would
	// make DROP DATABASE wait; the admin pool is closed even if the drop
	// fails. Close errors are not reported: only the drop result matters
	// to the caller.
	teardown := func() error {
		if sdb != nil {
			sdb.Close()
		}
		_, err := db.Exec("DROP DATABASE " + dbName)
		db.Close()
		return err
	}

	// t.Fatal exits through runtime.Goexit, which still runs deferred calls,
	// so this removes the database if any step below aborts the test.
	ready := false
	defer func() {
		if !ready {
			if err := teardown(); err != nil {
				t.Logf("dropping %s after failed setup: %v", dbName, err)
			}
		}
	}()

	sdb, err = sql.Open("mysql", dsn(dbName))
	if err != nil {
		t.Fatal(err)
	}

	files, err := sqlpath.Glob(migrations)
	if err != nil {
		t.Fatal(err)
	}
	for _, f := range files {
		blob, err := os.ReadFile(f)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := sdb.Exec(string(blob)); err != nil {
			t.Fatalf("%s: %s", filepath.Base(f), err)
		}
	}

	ready = true
	return sdb, func() {
		// Drop the test db after the test runs.
		if err := teardown(); err != nil {
			t.Fatal(err)
		}
	}
}