-
Notifications
You must be signed in to change notification settings - Fork 2
/
mysql.go
49 lines (43 loc) · 1.32 KB
/
mysql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
package sqltest
import (
"context"
"fmt"
"testing"
"time"
)
// MySQLTruncator empties tables in a MySQL database. The SQL statements it
// issues are executed through agent.
type MySQLTruncator struct {
	agent Agent // runs the SET / TRUNCATE statements
}
// MustTruncateAll is the test-failing variant of TruncateAll: it empties every
// table and fails the test via t when that does not succeed.
func (tr *MySQLTruncator) MustTruncateAll(t testing.TB) {
	mustTruncateAll(t, tr)
}
// TruncateAll empties every table in the database, delegating the work to the
// shared truncateAll helper with this truncator and its agent.
func (tr *MySQLTruncator) TruncateAll(t testing.TB) error {
	agent := tr.agent
	return truncateAll(t, tr, agent)
}
// MustTruncateTables is the test-failing variant of TruncateTables: it empties
// the named tables and fails the test via t when that does not succeed.
func (tr *MySQLTruncator) MustTruncateTables(t testing.TB, tables ...string) {
	mustTruncateTables(t, tr, tables...)
}
// TruncateTables removes all rows from each of the given tables, in order.
//
// Foreign key checks are switched off before truncating so that referenced
// tables can be emptied regardless of order. They are switched back on
// afterwards, even when a truncation fails. The truncation statements share a
// 5-second timeout. Calling TruncateTables with no tables is a no-op.
//
// The first error encountered is returned, wrapped with the step or table that
// failed. If re-enabling foreign key checks also fails, that error is returned
// only when nothing failed earlier.
//
// NOTE(review): SET FOREIGN_KEY_CHECKS changes a session variable. If Agent is
// backed by a connection pool (e.g. *sql.DB), the SET and TRUNCATE statements
// may run on different connections — confirm Agent pins a single connection.
//
// t is not used here; presumably it is kept so the signature matches the other
// truncator methods.
func (tr *MySQLTruncator) TruncateTables(t testing.TB, tables ...string) (err error) {
	const (
		setForeignKeysStmt = "SET FOREIGN_KEY_CHECKS=?"
		timeout            = 5 * time.Second
	)

	if len(tables) == 0 {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	if _, err = tr.agent.ExecContext(ctx, setForeignKeysStmt, false); err != nil {
		return fmt.Errorf("disabling foreign key checks: %w", err)
	}

	// Always restore foreign key checks so a failed truncation does not leave
	// the session with constraints disabled. A fresh context is used so the
	// restore still runs when ctx has already expired.
	defer func() {
		restoreCtx, restoreCancel := context.WithTimeout(context.Background(), timeout)
		defer restoreCancel()
		if _, rerr := tr.agent.ExecContext(restoreCtx, setForeignKeysStmt, true); rerr != nil && err == nil {
			err = fmt.Errorf("re-enabling foreign key checks: %w", rerr)
		}
	}()

	for _, table := range tables {
		if _, err = tr.agent.ExecContext(ctx, fmt.Sprintf(truncateStmtFmt, table)); err != nil {
			return fmt.Errorf("truncating table %q: %w", table, err)
		}
	}
	return nil
}