Skip to content

Commit

Permalink
Support embedded migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
amacneil committed Feb 21, 2023
1 parent 447ea69 commit be85b6b
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 196 deletions.
210 changes: 84 additions & 126 deletions pkg/dbmate/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"net/url"
"os"
"path/filepath"
Expand All @@ -30,7 +31,7 @@ var (
)

// migrationFileRegexp pattern for valid migration files
var migrationFileRegexp = regexp.MustCompile(`^\d.*\.sql$`)
var migrationFileRegexp = regexp.MustCompile(`^(\d+).*\.sql$`)

// DB allows dbmate actions to be performed on a specified database
type DB struct {
Expand All @@ -42,6 +43,8 @@ type DB struct {
Log io.Writer
// MigrationsDir specifies the directory to find migration files
MigrationsDir string
// MigrationsFS allows overriding the filesystem
MigrationsFS fs.FS
// MigrationsTableName specifies the database table to record migrations in
MigrationsTableName string
// SchemaFile specifies the location for schema.sql file
Expand All @@ -67,13 +70,14 @@ func New(databaseURL *url.URL) *DB {
return &DB{
AutoDumpSchema: true,
DatabaseURL: databaseURL,
Log: os.Stdout,
MigrationsDir: "./db/migrations",
MigrationsFS: os.DirFS("."),
MigrationsTableName: "schema_migrations",
SchemaFile: "./db/schema.sql",
WaitBefore: false,
WaitInterval: time.Second,
WaitTimeout: 60 * time.Second,
Log: os.Stdout,
}
}

Expand Down Expand Up @@ -290,59 +294,52 @@ func (db *DB) openDatabaseForMigration(drv Driver) (*sql.DB, error) {

// Migrate migrates database to the latest version
func (db *DB) Migrate() error {
files, err := findMigrationFiles(db.MigrationsDir, migrationFileRegexp)
drv, err := db.Driver()
if err != nil {
return err
}

if len(files) == 0 {
return ErrNoMigrationFiles
}

drv, err := db.Driver()
migrations, err := db.FindMigrations()
if err != nil {
return err
}

sqlDB, err := db.openDatabaseForMigration(drv)
if err != nil {
return err
if len(migrations) == 0 {
return ErrNoMigrationFiles
}
defer dbutil.MustClose(sqlDB)

applied, err := drv.SelectMigrations(sqlDB, -1)
sqlDB, err := db.openDatabaseForMigration(drv)
if err != nil {
return err
}
defer dbutil.MustClose(sqlDB)

for _, filename := range files {
ver := migrationVersion(filename)
if ok := applied[ver]; ok {
// migration already applied
for _, migration := range migrations {
if migration.Applied {
continue
}

fmt.Fprintf(db.Log, "Applying: %s\n", filename)
fmt.Fprintf(db.Log, "Applying: %s\n", migration.FileName)

up, _, err := parseMigration(filepath.Join(db.MigrationsDir, filename))
parsed, err := migration.Parse()
if err != nil {
return err
}

execMigration := func(tx dbutil.Transaction) error {
// run actual migration
result, err := tx.Exec(up.Contents)
result, err := tx.Exec(parsed.Up)
if err != nil {
return err
} else if db.Verbose {
db.printVerbose(result)
}

// record migration
return drv.InsertMigration(tx, ver)
return drv.InsertMigration(tx, migration.Version)
}

if up.Options.Transaction() {
if parsed.UpOptions.Transaction() {
// begin transaction
err = doTransaction(sqlDB, execMigration)
} else {
Expand Down Expand Up @@ -374,53 +371,69 @@ func (db *DB) printVerbose(result sql.Result) {
}
}

func findMigrationFiles(dir string, re *regexp.Regexp) ([]string, error) {
files, err := os.ReadDir(dir)
// FindMigrations lists all available migrations
func (db *DB) FindMigrations() ([]Migration, error) {
drv, err := db.Driver()
if err != nil {
return nil, err
}

sqlDB, err := drv.Open()
if err != nil {
return nil, err
}
defer dbutil.MustClose(sqlDB)

// find applied migrations
appliedMigrations := map[string]bool{}
migrationsTableExists, err := drv.MigrationsTableExists(sqlDB)
if err != nil {
return nil, fmt.Errorf("%w `%s`", ErrMigrationDirNotFound, dir)
return nil, err
}

if migrationsTableExists {
appliedMigrations, err = drv.SelectMigrations(sqlDB, -1)
if err != nil {
return nil, err
}
}

matches := []string{}
// find filesystem migrations
files, err := fs.ReadDir(db.MigrationsFS, filepath.Clean(db.MigrationsDir))
if err != nil {
return nil, fmt.Errorf("%w `%s`", ErrMigrationDirNotFound, db.MigrationsDir)
}

migrations := []Migration{}
for _, file := range files {
if file.IsDir() {
continue
}

name := file.Name()
if !re.MatchString(name) {
matches := migrationFileRegexp.FindStringSubmatch(file.Name())
if len(matches) < 2 {
continue
}

matches = append(matches, name)
}

sort.Strings(matches)

return matches, nil
}

func findMigrationFile(dir string, ver string) (string, error) {
if ver == "" {
panic("migration version is required")
}

ver = regexp.QuoteMeta(ver)
re := regexp.MustCompile(fmt.Sprintf(`^%s.*\.sql$`, ver))

files, err := findMigrationFiles(dir, re)
if err != nil {
return "", err
}
migration := Migration{
Applied: false,
FileName: matches[0],
FilePath: filepath.Join(db.MigrationsDir, matches[0]),
FS: db.MigrationsFS,
Version: matches[1],
}
if ok := appliedMigrations[migration.Version]; ok {
migration.Applied = true
}

if len(files) == 0 {
return "", fmt.Errorf("%w: %s*.sql", ErrMigrationNotFound, ver)
migrations = append(migrations, migration)
}

return files[0], nil
}
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].FileName < migrations[j].FileName
})

func migrationVersion(filename string) string {
return regexp.MustCompile(`^\d+`).FindString(filename)
return migrations, nil
}

// Rollback rolls back the most recent migration
Expand All @@ -436,46 +449,44 @@ func (db *DB) Rollback() error {
}
defer dbutil.MustClose(sqlDB)

applied, err := drv.SelectMigrations(sqlDB, 1)
// find last applied migration
var latest *Migration
migrations, err := db.FindMigrations()
if err != nil {
return err
}

// grab most recent applied migration (applied has len=1)
latest := ""
for ver := range applied {
latest = ver
}
if latest == "" {
return ErrNoRollback
for _, migration := range migrations {
if migration.Applied {
latest = &migration
}
}

filename, err := findMigrationFile(db.MigrationsDir, latest)
if err != nil {
return err
if latest == nil {
return ErrNoRollback
}

fmt.Fprintf(db.Log, "Rolling back: %s\n", filename)
fmt.Fprintf(db.Log, "Rolling back: %s\n", latest.FileName)

_, down, err := parseMigration(filepath.Join(db.MigrationsDir, filename))
parsed, err := latest.Parse()
if err != nil {
return err
}

execMigration := func(tx dbutil.Transaction) error {
// rollback migration
result, err := tx.Exec(down.Contents)
result, err := tx.Exec(parsed.Down)
if err != nil {
return err
} else if db.Verbose {
db.printVerbose(result)
}

// remove migration record
return drv.DeleteMigration(tx, latest)
return drv.DeleteMigration(tx, latest.Version)
}

if down.Options.Transaction() {
if parsed.DownOptions.Transaction() {
// begin transaction
err = doTransaction(sqlDB, execMigration)
} else {
Expand All @@ -497,7 +508,7 @@ func (db *DB) Rollback() error {

// Status shows the status of all migrations
func (db *DB) Status(quiet bool) (int, error) {
results, err := db.CheckMigrationsStatus()
results, err := db.FindMigrations()
if err != nil {
return -1, err
}
Expand All @@ -507,10 +518,10 @@ func (db *DB) Status(quiet bool) (int, error) {

for _, res := range results {
if res.Applied {
line = fmt.Sprintf("[X] %s", res.Filename)
line = fmt.Sprintf("[X] %s", res.FileName)
totalApplied++
} else {
line = fmt.Sprintf("[ ] %s", res.Filename)
line = fmt.Sprintf("[ ] %s", res.FileName)
}
if !quiet {
fmt.Fprintln(db.Log, line)
Expand All @@ -526,56 +537,3 @@ func (db *DB) Status(quiet bool) (int, error) {

return totalPending, nil
}

// CheckMigrationsStatus returns the status of all available mgirations
func (db *DB) CheckMigrationsStatus() ([]StatusResult, error) {
drv, err := db.Driver()
if err != nil {
return nil, err
}

files, err := findMigrationFiles(db.MigrationsDir, migrationFileRegexp)
if err != nil {
return nil, err
}

if len(files) == 0 {
return nil, ErrNoMigrationFiles
}

sqlDB, err := drv.Open()
if err != nil {
return nil, err
}
defer dbutil.MustClose(sqlDB)

applied := map[string]bool{}

migrationsTableExists, err := drv.MigrationsTableExists(sqlDB)
if err != nil {
return nil, err
}

if migrationsTableExists {
applied, err = drv.SelectMigrations(sqlDB, -1)
if err != nil {
return nil, err
}
}

var results []StatusResult

for _, filename := range files {
ver := migrationVersion(filename)
res := StatusResult{Filename: filename}
if ok := applied[ver]; ok {
res.Applied = true
} else {
res.Applied = false
}

results = append(results, res)
}

return results, nil
}
Loading

0 comments on commit be85b6b

Please sign in to comment.