From b3b1ccd46dca62c7328471eee106e2afd6bea32b Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Fri, 11 Aug 2023 00:05:44 +0300 Subject: [PATCH] sql/migrate: add support for directory checkpoints (#1971) --- cmd/atlas/internal/cmdapi/migrate.go | 6 - doc/md/reference.md | 1 - sql/migrate/dir.go | 293 ++++++++++++++++++++++----- sql/migrate/dir_test.go | 140 +++++++++++++ sql/migrate/migrate.go | 38 +--- sql/migrate/migrate_test.go | 37 ---- 6 files changed, 388 insertions(+), 127 deletions(-) diff --git a/cmd/atlas/internal/cmdapi/migrate.go b/cmd/atlas/internal/cmdapi/migrate.go index 344a8a79c75..d73a5044d3c 100644 --- a/cmd/atlas/internal/cmdapi/migrate.go +++ b/cmd/atlas/internal/cmdapi/migrate.go @@ -75,7 +75,6 @@ type migrateApplyFlags struct { logFormat string lockTimeout time.Duration allowDirty bool // allow working on a database that already has resources - fromVersion string // compute pending files based on this version baselineVersion string // apply with this version as baseline txMode string // (none, file, all) } @@ -87,9 +86,6 @@ func (f *migrateApplyFlags) migrateOptions() (opts []migrate.ExecutorOption) { if v := f.baselineVersion; v != "" { opts = append(opts, migrate.WithBaselineVersion(v)) } - if v := f.fromVersion; v != "" { - opts = append(opts, migrate.WithFromVersion(v)) - } return } @@ -141,11 +137,9 @@ If run with the "--dry-run" flag, atlas will not execute any SQL.`, addFlagRevisionSchema(cmd.Flags(), &flags.revisionSchema) addFlagDryRun(cmd.Flags(), &flags.dryRun) addFlagLockTimeout(cmd.Flags(), &flags.lockTimeout) - cmd.Flags().StringVarP(&flags.fromVersion, flagFrom, "", "", "calculate pending files from the given version (including it)") cmd.Flags().StringVarP(&flags.baselineVersion, flagBaseline, "", "", "start the first migration after the given baseline version") cmd.Flags().StringVarP(&flags.txMode, flagTxMode, "", txModeFile, "set transaction mode [none, file, all]") cmd.Flags().BoolVarP(&flags.allowDirty, flagAllowDirty, "", false, "allow start working on a non-clean database") - cmd.MarkFlagsMutuallyExclusive(flagFrom, flagBaseline) cmd.MarkFlagsMutuallyExclusive(flagLog, flagFormat) return cmd } diff --git a/doc/md/reference.md b/doc/md/reference.md index dcf4ca3bcc0..2d4be65e998 100644 --- a/doc/md/reference.md +++ b/doc/md/reference.md @@ -92,7 +92,6 @@ If run with the "--dry-run" flag, atlas will not execute any SQL. --revisions-schema string name of the schema the revisions table resides in --dry-run print SQL without executing it --lock-timeout duration set how long to wait for the database lock (default 10s) - --from string calculate pending files from the given version (including it) --baseline string start the first migration after the given baseline version --tx-mode string set transaction mode [none, file, all] (default "file") --allow-dirty allow start working on a non-clean database diff --git a/sql/migrate/dir.go b/sql/migrate/dir.go index 06bef37380b..6192b892897 100644 --- a/sql/migrate/dir.go +++ b/sql/migrate/dir.go @@ -38,12 +38,6 @@ type ( Checksum() (HashFile, error) } - // Formatter wraps the Format method. - Formatter interface { - // Format formats the given Plan into one or more migration files. - Format(*Plan) ([]File, error) - } - // File represents a single migration file. File interface { // Name returns the name of the migration file. @@ -59,6 +53,50 @@ type ( // StmtDecls returns the set of SQL statements this file holds alongside its preceding comments. StmtDecls() ([]*Stmt, error) } + + // Formatter wraps the Format method. + Formatter interface { + // Format formats the given Plan into one or more migration files. + Format(*Plan) ([]File, error) + } +) + +type ( + // CheckpointDir wraps the functionality used to interact + // with a migration directory that support checkpoints. + CheckpointDir interface { + // WriteCheckpoint writes the given checkpoint file to the migration directory. + WriteCheckpoint(name, tag string, content []byte) error + + // CheckpointFiles returns a set of checkpoint files stored in this Dir, + // ordered by name. + CheckpointFiles() ([]File, error) + + // FilesFromCheckpoint returns the files to be executed on a database from + // the given checkpoint file, including it. An ErrCheckpointNotFound if the + // checkpoint file is not found in the directory. + FilesFromCheckpoint(string) ([]File, error) + } + + // CheckpointFile wraps the functionality used to interact with files + // returned from a CheckpointDir. + CheckpointFile interface { + // IsCheckpoint returns true if the file is a checkpoint file. + IsCheckpoint() bool + + // CheckpointTag returns the tag of the checkpoint file, if defined. The tag + // can be derived from the file name, internal metadata or directive comments. + // + // An ErrNotCheckpoint is returned if the file is not a checkpoint file. + CheckpointTag() (string, error) + } +) + +var ( + // ErrNotCheckpoint is returned when calling CheckpointFile methods on a non-checkpoint file. + ErrNotCheckpoint = errors.New("not a checkpoint file") + // ErrCheckpointNotFound is returned when a checkpoint file is not found in the directory. + ErrCheckpointNotFound = errors.New("no checkpoint found") ) // LocalDir implements Dir for a local migration @@ -67,7 +105,10 @@ type LocalDir struct { path string } -var _ Dir = (*LocalDir)(nil) +var _ interface { + Dir + CheckpointDir +} = (*LocalDir)(nil) // NewLocalDir returns a new the Dir used by a Planner to work on the given local path. func NewLocalDir(path string) (*LocalDir, error) { @@ -110,15 +151,15 @@ func (d *LocalDir) Files() ([]File, error) { sort.Slice(names, func(i, j int) bool { return names[i] < names[j] }) - ret := make([]File, len(names)) - for i, n := range names { + files := make([]File, 0, len(names)) + for _, n := range names { b, err := fs.ReadFile(d, n) if err != nil { return nil, fmt.Errorf("sql/migrate: read file %q: %w", n, err) } - ret[i] = NewLocalFile(n, b) + files = append(files, NewLocalFile(n, b)) } - return ret, nil + return files, nil } // Checksum implements Dir.Checksum. By default, it calls Files() and creates a checksum from them. @@ -130,13 +171,33 @@ func (d *LocalDir) Checksum() (HashFile, error) { return NewHashFile(files) } +// WriteCheckpoint is like WriteFile, but marks the file as a checkpoint file. +func (d *LocalDir) WriteCheckpoint(name, tag string, b []byte) error { + f := NewLocalFile(name, b) + f.AddDirective(directiveCheckpoint, tag) + return d.WriteFile(name, f.Bytes()) +} + +// CheckpointFiles implements CheckpointDir.CheckpointFiles. +func (d *LocalDir) CheckpointFiles() ([]File, error) { + return checkpointFiles(d) +} + +// FilesFromCheckpoint implements CheckpointDir.FilesFromCheckpoint. +func (d *LocalDir) FilesFromCheckpoint(name string) ([]File, error) { + return filesFromCheckpoint(d, name) +} + // LocalFile is used by LocalDir to implement the Scanner interface. type LocalFile struct { n string b []byte } -var _ File = (*LocalFile)(nil) +var _ interface { + File + CheckpointFile +} = (*LocalFile)(nil) // NewLocalFile returns a new local file. func NewLocalFile(name string, data []byte) *LocalFile { @@ -144,12 +205,12 @@ func NewLocalFile(name string, data []byte) *LocalFile { } // Name implements File.Name. -func (f LocalFile) Name() string { +func (f *LocalFile) Name() string { return f.n } // Desc implements File.Desc. -func (f LocalFile) Desc() string { +func (f *LocalFile) Desc() string { parts := strings.SplitN(f.n, "_", 2) if len(parts) == 1 { return "" @@ -158,12 +219,12 @@ func (f LocalFile) Desc() string { } // Version implements File.Version. -func (f LocalFile) Version() string { +func (f *LocalFile) Version() string { return strings.SplitN(strings.TrimSuffix(f.n, ".sql"), "_", 2)[0] } // Stmts returns the SQL statement exists in the local file. -func (f LocalFile) Stmts() ([]string, error) { +func (f *LocalFile) Stmts() ([]string, error) { s, err := Stmts(string(f.b)) if err != nil { return nil, err @@ -176,19 +237,85 @@ func (f LocalFile) Stmts() ([]string, error) { } // StmtDecls returns the all statement declarations exist in the local file. -func (f LocalFile) StmtDecls() ([]*Stmt, error) { +func (f *LocalFile) StmtDecls() ([]*Stmt, error) { return Stmts(string(f.b)) } // Bytes returns local file data. -func (f LocalFile) Bytes() []byte { +func (f *LocalFile) Bytes() []byte { return f.b } +// IsCheckpoint reports whether the file is a checkpoint file. +func (f *LocalFile) IsCheckpoint() bool { + return len(f.Directive(directiveCheckpoint)) > 0 +} + +// CheckpointTag returns the tag of the checkpoint file, if defined. +func (f *LocalFile) CheckpointTag() (string, error) { + ds := f.Directive(directiveCheckpoint) + if len(ds) == 0 { + return "", ErrNotCheckpoint + } + return ds[0], nil +} + +const ( + // atlas:sum directive. + directiveSum = "sum" + sumModeIgnore = "ignore" + // atlas:delimiter directive. + directiveDelimiter = "delimiter" + // atlas:checkpoint directive. + directiveCheckpoint = "checkpoint" + directivePrefixSQL = "-- " +) + +var reDirective = regexp.MustCompile(`^([ -~]*)atlas:(\w+)(?: +([ -~]*))*`) + +// directive searches in the content a line that matches a directive +// with the given prefix and name. For example: +// +// directive(c, "delimiter", "-- ") // '-- atlas:delimiter.*' +// directive(c, "sum", "") // 'atlas:sum.*' +// directive(c, "sum") // '.*atlas:sum' +func directive(content, name string, prefix ...string) (string, bool) { + m := reDirective.FindStringSubmatch(content) + // In case the prefix was provided ensures it is matched. + if len(m) == 4 && m[2] == name && (len(prefix) == 0 || prefix[0] == m[1]) { + return m[3], true + } + return "", false +} + // Directive returns the (global) file directives that match the provided name. // File directives are located at the top of the file and should not be associated with any // statement. Hence, double new lines are used to separate file directives from its content. -func (f LocalFile) Directive(name string) (ds []string) { +func (f *LocalFile) Directive(name string) (ds []string) { + for _, c := range f.comments() { + if d, ok := directive(c, name); ok { + ds = append(ds, d) + } + } + return ds +} + +// AddDirective adds a new directive to the file. +func (f *LocalFile) AddDirective(name string, args ...string) { + var b strings.Builder + b.WriteString("-- atlas:" + name) + if len(args) > 0 { + b.WriteByte(' ') + b.WriteString(strings.Join(args, " ")) + } + b.WriteByte('\n') + if len(f.comments()) == 0 { + b.WriteByte('\n') + } + f.b = append([]byte(b.String()), f.b...) +} + +func (f *LocalFile) comments() []string { var ( comments []string content = string(f.b) @@ -203,23 +330,18 @@ func (f LocalFile) Directive(name string) (ds []string) { comments = append(comments, strings.TrimSpace(content[:idx])) content = content[idx+1:] } - // File directives are separated by - // double newlines from file content. + // File comments are separated by double newlines from + // file content (detached from actual statements). if !strings.HasPrefix(content, "\n") { return nil } - for _, c := range comments { - if d, ok := directive(c, name); ok { - ds = append(ds, d) - } - } - return ds + return comments } type ( // MemDir provides an in-memory Dir implementation. MemDir struct { - files map[string]File + files map[string]*LocalFile syncTo []func(string, []byte) error } // An opened MemDir. @@ -229,11 +351,17 @@ type ( } ) -// A list of the opened memory-based directories. -var memDirs struct { - sync.Mutex - opened map[string]*openedMem -} +var ( + // A list of the opened memory-based directories. + memDirs struct { + sync.Mutex + opened map[string]*openedMem + } + _ interface { + Dir + CheckpointDir + } = (*MemDir)(nil) +) // OpenMemDir opens an in-memory directory and registers it in the process namespace // with the given name. Hence, calling OpenMemDir with the same name will return the @@ -292,7 +420,7 @@ func (d *MemDir) Close() error { // WriteFile adds a new file in-memory. func (d *MemDir) WriteFile(name string, data []byte) error { if d.files == nil { - d.files = make(map[string]File) + d.files = make(map[string]*LocalFile) } d.files[name] = NewLocalFile(name, data) for _, f := range d.syncTo { @@ -303,6 +431,23 @@ func (d *MemDir) WriteFile(name string, data []byte) error { return nil } +// WriteCheckpoint is like WriteFile, but marks the file as a checkpoint file. +func (d *MemDir) WriteCheckpoint(name, tag string, b []byte) error { + f := NewLocalFile(name, b) + f.AddDirective(directiveCheckpoint, tag) + return d.WriteFile(name, f.Bytes()) +} + +// CheckpointFiles implements CheckpointDir.CheckpointFiles. +func (d *MemDir) CheckpointFiles() ([]File, error) { + return checkpointFiles(d) +} + +// FilesFromCheckpoint implements CheckpointDir.FilesFromCheckpoint. +func (d *MemDir) FilesFromCheckpoint(name string) ([]File, error) { + return filesFromCheckpoint(d, name) +} + // SyncWrites allows syncing writes from in-memory directory // the underlying storage. func (d *MemDir) SyncWrites(fs ...func(string, []byte) error) { @@ -512,7 +657,7 @@ func Validate(dir Dir) error { // FilesLastIndex returns the index of the last file // satisfying f(i), or -1 if none do. -func FilesLastIndex(files []File, f func(File) bool) int { +func FilesLastIndex[F File](files []F, f func(F) bool) int { for i := len(files) - 1; i >= 0; i-- { if f(files[i]) { return i @@ -521,30 +666,66 @@ func FilesLastIndex(files []File, f func(File) bool) int { return -1 } -const ( - // atlas:sum directive. - directiveSum = "sum" - sumModeIgnore = "ignore" - // atlas:delimiter directive. - directiveDelimiter = "delimiter" - directivePrefixSQL = "-- " -) +// SkipCheckpointFiles returns a filtered set of files that are not checkpoint files. +func SkipCheckpointFiles(all []File) []File { + files := make([]File, 0, len(all)) + for _, f := range all { + if ck, ok := f.(CheckpointFile); ok && ck.IsCheckpoint() { + continue + } + files = append(files, f) + } + return files +} -var reDirective = regexp.MustCompile(`^([ -~]*)atlas:(\w+)(?: +([ -~]*))*`) +// FilesFromLastCheckpoint returns a set of files created after the last checkpoint, +// if exists, to be executed on a database (on the first time). Note, if the Dir is +// not a CheckpointDir, or no checkpoint file was found, all files are returned. +func FilesFromLastCheckpoint(dir Dir) ([]File, error) { + ck, ok := dir.(CheckpointDir) + if !ok { + return dir.Files() + } + cks, err := ck.CheckpointFiles() + if err != nil { + return nil, err + } + if len(cks) == 0 { + return dir.Files() + } + return ck.FilesFromCheckpoint(cks[len(cks)-1].Name()) +} -// directive searches in the content a line that matches a directive -// with the given prefix and name. For example: -// -// directive(c, "delimiter", "-- ") // '-- atlas:delimiter.*' -// directive(c, "sum", "") // 'atlas:sum.*' -// directive(c, "sum") // '.*atlas:sum' -func directive(content, name string, prefix ...string) (string, bool) { - m := reDirective.FindStringSubmatch(content) - // In case the prefix was provided ensures it is matched. - if len(m) == 4 && m[2] == name && (len(prefix) == 0 || prefix[0] == m[1]) { - return m[3], true +// checkpointFiles returns all checkpoint files in a migration directory. +func checkpointFiles(d Dir) ([]File, error) { + files, err := d.Files() + if err != nil { + return nil, err } - return "", false + var cks []File + for _, f := range files { + if ck, ok := f.(CheckpointFile); ok && ck.IsCheckpoint() { + cks = append(cks, f) + } + } + return cks, nil +} + +// filesFromCheckpoint returns all files from the given checkpoint +// to be executed on the database, including the checkpoint file. +func filesFromCheckpoint(d Dir, name string) ([]File, error) { + files, err := d.Files() + if err != nil { + return nil, err + } + i := FilesLastIndex(files, func(f File) bool { + c, ok := f.(CheckpointFile) + return ok && c.IsCheckpoint() && f.Name() == name + }) + if i == -1 { + return nil, ErrCheckpointNotFound + } + return files[i:], nil } // readHashFile reads the HashFile from the given Dir. diff --git a/sql/migrate/dir_test.go b/sql/migrate/dir_test.go index 443c7a13c45..4942b452fcf 100644 --- a/sql/migrate/dir_test.go +++ b/sql/migrate/dir_test.go @@ -199,6 +199,92 @@ func TestLocalDir(t *testing.T) { require.Equal(t, "description", files[1].Desc()) } +func TestCheckpointDir(t *testing.T) { + local, err := migrate.NewLocalDir(t.TempDir()) + require.NoError(t, err) + for _, d := range []interface { + migrate.Dir + migrate.CheckpointDir + }{&migrate.MemDir{}, local} { + files, err := d.Files() + require.NoError(t, err) + require.Empty(t, files) + cks, err := d.CheckpointFiles() + require.NoError(t, err) + require.Empty(t, cks) + require.NoError(t, migrate.Validate(d)) + + require.NoError(t, d.WriteFile("1.sql", []byte("create table t1(c int);"))) + sum, err := d.Checksum() + require.NoError(t, err) + require.NoError(t, migrate.WriteSumFile(d, sum)) + require.NoError(t, migrate.Validate(d)) + files, err = d.Files() + require.NoError(t, err) + require.Len(t, files, 1) + cks, err = d.CheckpointFiles() + require.NoError(t, err) + require.Empty(t, cks) + + require.NoError(t, d.WriteCheckpoint("2_checkpoint.sql", "", []byte("create table t1(c int);"))) + sum, err = d.Checksum() + require.NoError(t, err) + require.NoError(t, migrate.WriteSumFile(d, sum)) + require.NoError(t, migrate.Validate(d)) + files, err = d.Files() + require.NoError(t, err) + require.Len(t, files, 2) + require.Equal(t, []string{"1.sql", "2_checkpoint.sql"}, []string{files[0].Name(), files[1].Name()}) + files = migrate.SkipCheckpointFiles(files) + require.Len(t, files, 1) + require.Equal(t, "1.sql", files[0].Name()) + cks, err = d.CheckpointFiles() + require.NoError(t, err) + require.Len(t, cks, 1) + require.Equal(t, "2_checkpoint.sql", cks[0].Name()) + + require.NoError(t, d.WriteFile("3.sql", []byte("create table t2(c int);"))) + sum, err = d.Checksum() + require.NoError(t, err) + require.NoError(t, migrate.WriteSumFile(d, sum)) + require.NoError(t, migrate.Validate(d)) + files, err = d.Files() + require.NoError(t, err) + require.Len(t, files, 3) + require.Equal(t, []string{"1.sql", "2_checkpoint.sql", "3.sql"}, []string{files[0].Name(), files[1].Name(), files[2].Name()}) + files = migrate.SkipCheckpointFiles(files) + require.Len(t, files, 2) + require.Equal(t, []string{"1.sql", "3.sql"}, []string{files[0].Name(), files[1].Name()}) + cks, err = d.CheckpointFiles() + require.NoError(t, err) + require.Len(t, cks, 1) + require.Equal(t, "2_checkpoint.sql", cks[0].Name()) + + require.NoError(t, d.WriteCheckpoint("4_checkpoint.sql", "v4", []byte("create table t1(c int);\ncreate table t2(c int);"))) + sum, err = d.Checksum() + require.NoError(t, err) + require.NoError(t, migrate.WriteSumFile(d, sum)) + require.NoError(t, migrate.Validate(d)) + files, err = d.Files() + require.NoError(t, err) + require.Len(t, files, 4) + require.Equal(t, []string{"1.sql", "2_checkpoint.sql", "3.sql", "4_checkpoint.sql"}, []string{files[0].Name(), files[1].Name(), files[2].Name(), files[3].Name()}) + files = migrate.SkipCheckpointFiles(files) + require.Len(t, files, 2) + require.Equal(t, []string{"1.sql", "3.sql"}, []string{files[0].Name(), files[1].Name()}) + cks, err = d.CheckpointFiles() + require.NoError(t, err) + require.Len(t, cks, 2) + require.Equal(t, []string{"2_checkpoint.sql", "4_checkpoint.sql"}, []string{cks[0].Name(), cks[1].Name()}) + tag, err := cks[0].(migrate.CheckpointFile).CheckpointTag() + require.NoError(t, err) + require.Equal(t, "", tag) + tag, err = cks[1].(migrate.CheckpointFile).CheckpointTag() + require.NoError(t, err) + require.Equal(t, "v4", tag) + } +} + func TestMemDir(t *testing.T) { var d migrate.MemDir files, err := d.Files() @@ -305,6 +391,12 @@ alter table pets drop column id; require.Equal(t, []string{"ignore"}, f.Directive("lint"), "first directive from two") require.Equal(t, []string{"none"}, f.Directive("txmode"), "second directive from two") + f = migrate.NewLocalFile("1.sql", []byte(`-- atlas:nolint + +alter table users drop column id; +`)) + require.Equal(t, []string{""}, f.Directive("nolint"), "directives without arguments returned as empty string") + f = migrate.NewLocalFile("1.sql", nil) require.Empty(t, f.Directive("lint")) f = migrate.NewLocalFile("1.sql", []byte("-- atlas:lint ignore")) @@ -315,6 +407,54 @@ alter table pets drop column id; require.Equal(t, []string{"ignore"}, f.Directive("lint"), "double newline as directive separator") } +func TestLocalFile_AddDirective(t *testing.T) { + f := migrate.NewLocalFile("1.sql", []byte("SELECT 1;")) + f.AddDirective("lint", "ignore") + require.Equal(t, []string{"ignore"}, f.Directive("lint")) + require.Equal(t, "-- atlas:lint ignore\n\nSELECT 1;", string(f.Bytes())) + f.AddDirective("checkpoint") + require.Equal(t, []string{"ignore"}, f.Directive("lint")) + require.Equal(t, []string{""}, f.Directive("checkpoint")) + require.Equal(t, `-- atlas:checkpoint +-- atlas:lint ignore + +SELECT 1;`, string(f.Bytes())) + + f = migrate.NewLocalFile("1.sql", []byte("-- atlas:directive statement directive\nSELECT 1;")) + f.AddDirective("lint", "ignore") + require.Equal(t, []string{"ignore"}, f.Directive("lint")) + require.Equal(t, `-- atlas:lint ignore + +-- atlas:directive statement directive +SELECT 1;`, string(f.Bytes())) +} + +func TestLocalFile_CheckpointTag(t *testing.T) { + // Not a checkpoint. + for _, b := range []string{ + "SELECT 1;", + "-- atlas:checkpoint\nSELECT 1;", + "-- atlas:checkpoint tag\nSELECT 1;", + } { + f := migrate.NewLocalFile("1.sql", []byte(b)) + require.False(t, f.IsCheckpoint()) + tag, err := f.CheckpointTag() + require.ErrorIs(t, err, migrate.ErrNotCheckpoint) + require.Empty(t, tag) + } + // Checkpoint. + f := migrate.NewLocalFile("1.sql", []byte("-- atlas:checkpoint\n\nSELECT 1;")) + require.True(t, f.IsCheckpoint()) + tag, err := f.CheckpointTag() + require.NoError(t, err) + require.Empty(t, tag) + f = migrate.NewLocalFile("1.sql", []byte("-- atlas:checkpoint tag\n\nSELECT 1;")) + require.True(t, f.IsCheckpoint()) + tag, err = f.CheckpointTag() + require.NoError(t, err) + require.Equal(t, "tag", tag) +} + func TestDirTar(t *testing.T) { d := migrate.OpenMemDir("") defer d.Close() diff --git a/sql/migrate/migrate.go b/sql/migrate/migrate.go index ee718201faf..1c5777665be 100644 --- a/sql/migrate/migrate.go +++ b/sql/migrate/migrate.go @@ -239,7 +239,6 @@ type ( dir Dir // The Dir with migration files to use. rrw RevisionReadWriter // The RevisionReadWriter to read and write database revisions to. log Logger // The Logger to use. - fromVer string // Calculate pending files from the given version (including it). baselineVer string // Start the first migration after the given baseline version. allowDirty bool // Allow start working on a non-clean database. operator string // Revision.OperatorVersion @@ -461,7 +460,7 @@ func (p *Planner) WritePlan(plan *Plan) error { var ( // ErrNoPendingFiles is returned if there are no pending migration files to execute on the managed database. - ErrNoPendingFiles = errors.New("sql/migrate: execute: nothing to do") + ErrNoPendingFiles = errors.New("sql/migrate: no pending migration files") // ErrSnapshotUnsupported is returned if there is no Snapshoter given. ErrSnapshotUnsupported = errors.New("sql/migrate: driver does not support taking a database snapshot") // ErrCleanCheckerUnsupported is returned if there is no CleanChecker given. @@ -540,15 +539,6 @@ func WithLogger(log Logger) ExecutorOption { } } -// WithFromVersion allows passing a file version as a starting point for calculating -// pending migration scripts. It can be useful for skipping specific files. -func WithFromVersion(v string) ExecutorOption { - return func(ex *Executor) error { - ex.fromVer = v - return nil - } -} - // WithOperatorVersion sets the operator version to save on the revisions // when executing migration files. func WithOperatorVersion(v string) ExecutorOption { @@ -569,14 +559,11 @@ func (e *Executor) Pending(ctx context.Context) ([]File, error) { if err != nil { return nil, fmt.Errorf("sql/migrate: execute: read revisions: %w", err) } - // Select the correct migration files. migrations, err := e.dir.Files() if err != nil { return nil, fmt.Errorf("sql/migrate: execute: select migration files: %w", err) } - if len(migrations) == 0 { - return nil, ErrNoPendingFiles - } + migrations = SkipCheckpointFiles(migrations) var pending []File switch { // If it is the first time we run. @@ -589,7 +576,6 @@ func (e *Executor) Pending(ctx context.Context) ([]File, error) { if cerr != nil && !e.allowDirty && e.baselineVer == "" { return nil, fmt.Errorf("%w. baseline version or allow-dirty is required", cerr) } - pending = migrations if e.baselineVer != "" { baseline := FilesLastIndex(migrations, func(f File) bool { return f.Version() == e.baselineVer @@ -598,22 +584,20 @@ func (e *Executor) Pending(ctx context.Context) ([]File, error) { return nil, fmt.Errorf("baseline version %q not found", e.baselineVer) } f := migrations[baseline] - // Mark the revision in the database as baseline revision. + // Write the first revision in the database as a baseline revision. if err := e.writeRevision(ctx, &Revision{Version: f.Version(), Description: f.Desc(), Type: RevisionTypeBaseline}); err != nil { return nil, err } pending = migrations[baseline+1:] + + // In case the "allow-dirty" option was set, or the database is clean, + // the starting-point is the first migration file or the last checkpoint. + } else if pending, err = FilesFromLastCheckpoint(e.dir); err != nil { + return nil, err } - // Not the first time we execute and a custom starting point was provided. - case e.fromVer != "": - idx := FilesLastIndex(migrations, func(f File) bool { - return f.Version() == e.fromVer - }) - if idx == -1 { - return nil, fmt.Errorf("starting point version %q not found in the migration directory", e.fromVer) - } - pending = migrations[idx:] - default: + // In case we applied/marked revisions in + // the past, and there is work to do. + case len(migrations) > 0: var ( last = revs[len(revs)-1] partially = last.Applied != last.Total diff --git a/sql/migrate/migrate_test.go b/sql/migrate/migrate_test.go index e9d41a1c26f..95e54c64901 100644 --- a/sql/migrate/migrate_test.go +++ b/sql/migrate/migrate_test.go @@ -488,43 +488,6 @@ func TestExecutor_Baseline(t *testing.T) { require.Equal(t, migrate.RevisionTypeBaseline, rrw[0].Type) } -func TestExecutor_FromVersion(t *testing.T) { - var ( - drv = &mockDriver{} - log = &mockLogger{} - rrw = &mockRevisionReadWriter{ - { - Version: "1.a", - Description: "sub.up", - Applied: 2, - Total: 2, - Hash: "nXyZR020M/mH7LxkoTkJr7BcQkipVg90imQ9I4595dw=", - }, - } - ) - dir, err := migrate.NewLocalDir(filepath.Join("testdata/migrate", "sub")) - require.NoError(t, err) - ex, err := migrate.NewExecutor(drv, dir, rrw, migrate.WithLogger(log)) - require.NoError(t, err) - files, err := ex.Pending(context.Background()) - require.NoError(t, err) - require.Len(t, files, 2) - - // Control the starting point. - ex, err = migrate.NewExecutor(drv, dir, rrw, migrate.WithLogger(log), migrate.WithFromVersion("3")) - require.NoError(t, err) - files, err = ex.Pending(context.Background()) - require.NoError(t, err) - require.Len(t, files, 1) - - // Starting point was not found. - ex, err = migrate.NewExecutor(drv, dir, rrw, migrate.WithLogger(log), migrate.WithFromVersion("4")) - require.NoError(t, err) - files, err = ex.Pending(context.Background()) - require.EqualError(t, err, `starting point version "4" not found in the migration directory`) - require.Nil(t, files) -} - type ( mockDriver struct { migrate.Driver