Skip to content

Commit

Permalink
ANZX-146383 - Adding Verbose command (#51)
Browse files Browse the repository at this point in the history
* adding verbose command

* updating README

* new line at end of output

* adding RowsAffected to non-partitionedDML

* riunning readme generator

* sorting output by key
  • Loading branch information
SamMcEachern committed Mar 7, 2024
1 parent ef8889c commit 1e586c3
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 28 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ Flags:
--schema-file string Name of schema file (optional. if not set, will use default 'schema.sql' file name)
--sequence-interval uint16 Used to generate the next migration id. Rounds up to the next interval. (optional. if not set, will use $WRENCH_SEQUENCE_INTERVAL or default to 1) (default 1)
--static-data-tables-file string File containing list of static data tables to track (optional)
--verbose Used to indicate whether to output Migration information during a migration
-v, --version version for wrench
Use "wrench [command] --help" for more information about a command.
Expand Down
1 change: 1 addition & 0 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ const (
flagNameSchemaFile = "schema-file"
flagLockIdentifier = "lock-identifier"
flagSequenceInterval = "sequence-interval"
flagVerbose = "verbose"
flagDDLFile = "ddl"
flagDMLFile = "dml"
flagPartitioned = "partitioned"
Expand Down
17 changes: 15 additions & 2 deletions cmd/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,17 +211,30 @@ func migrateUp(c *cobra.Command, args []string) error {
err: err,
}
}

var migrationsOutput spanner.MigrationsOutput
switch status {
case spanner.ExistingMigrationsUpgradeStarted:
return client.UpgradeExecuteMigrations(ctx, migrations, limit, migrationTableName)
migrationsOutput, err = client.UpgradeExecuteMigrations(ctx, migrations, limit, migrationTableName)
if err != nil {
return err
}
case spanner.ExistingMigrationsUpgradeCompleted:
return client.ExecuteMigrations(ctx, migrations, limit, migrationTableName)
migrationsOutput, err = client.ExecuteMigrations(ctx, migrations, limit, migrationTableName)
if err != nil {
return err
}
default:
return &Error{
cmd: c,
err: errors.New("migration in undetermined state"),
}
}
if verbose {
fmt.Print(migrationsOutput.String())
}

return nil
}

func migrateVersion(c *cobra.Command, args []string) error {
Expand Down
2 changes: 2 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ var (
staticDataTablesFile string
lockIdentifier string
sequenceInterval uint16
verbose bool
)

var rootCmd = &cobra.Command{
Expand Down Expand Up @@ -84,6 +85,7 @@ func init() {
rootCmd.PersistentFlags().StringVar(&staticDataTablesFile, flagStaticDataTablesFile, "", "File containing list of static data tables to track (optional)")
rootCmd.PersistentFlags().StringVar(&lockIdentifier, flagLockIdentifier, getLockIdentifier(), "Random identifier used to lock migration operations to a single wrench process. (optional. if not set then it will be generated)")
rootCmd.PersistentFlags().Uint16Var(&sequenceInterval, flagSequenceInterval, getSequenceInterval(), "Used to generate the next migration id. Rounds up to the next interval. (optional. if not set, will use $WRENCH_SEQUENCE_INTERVAL or default to 1)")
rootCmd.PersistentFlags().BoolVar(&verbose, flagVerbose, false, "Used to indicate whether to output Migration information during a migration")

rootCmd.Version = versioninfo.Version
rootCmd.SetVersionTemplate(versionTemplate)
Expand Down
81 changes: 62 additions & 19 deletions pkg/spanner/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,18 +446,23 @@ func (c *Client) ApplyPartitionedDML(ctx context.Context, statements []string) (
return numAffectedRows, nil
}

func (c *Client) UpgradeExecuteMigrations(ctx context.Context, migrations Migrations, limit int, tableName string) error {
func (c *Client) UpgradeExecuteMigrations(ctx context.Context, migrations Migrations, limit int, tableName string) (MigrationsOutput, error) {
err := c.backfillMigrations(ctx, migrations, tableName)
if err != nil {
return err
return nil, err
}

err = c.ExecuteMigrations(ctx, migrations, limit, tableName)
migrationsOutput, err := c.ExecuteMigrations(ctx, migrations, limit, tableName)
if err != nil {
return err
return nil, err
}

return c.markUpgradeComplete(ctx)
err = c.markUpgradeComplete(ctx)
if err != nil {
return nil, err
}

return migrationsOutput, nil
}

func (c *Client) backfillMigrations(ctx context.Context, migrations Migrations, tableName string) error {
Expand Down Expand Up @@ -549,30 +554,57 @@ func (c *Client) GetMigrationHistory(ctx context.Context, versionTableName strin
return history, nil
}

func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, limit int, tableName string) error {
type MigrationsOutput map[string]migrationInfo

type migrationInfo struct {
RowsAffected int64
}

func (i MigrationsOutput) String() string {
if len(i) == 0 {
return ""
}

var filenames []string
for filename := range i {
filenames = append(filenames, filename)
}

sort.StringSlice(filenames).Sort()

output := "Migration Information:"
for _, filename := range filenames {
migrationInfo := i[filename]
output = fmt.Sprintf("%s\n%s - rows affected: %d", output, filename, migrationInfo.RowsAffected)
}

return fmt.Sprintf("%s\n", output)
}

func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, limit int, tableName string) (MigrationsOutput, error) {
sort.Sort(migrations)

version, dirty, err := c.GetSchemaMigrationVersion(ctx, tableName)
if err != nil {
var se *Error
if !errors.As(err, &se) || se.Code != ErrorCodeNoMigration {
return &Error{
return nil, &Error{
Code: ErrorCodeExecuteMigrations,
err: err,
}
}
}

if dirty {
return &Error{
return nil, &Error{
Code: ErrorCodeMigrationVersionDirty,
err: fmt.Errorf("Database version: %d is dirty, please fix it.", version),
err: fmt.Errorf("database version: %d is dirty, please fix it.", version),
}
}

history, err := c.GetMigrationHistory(ctx, tableName)
if err != nil {
return &Error{
return nil, &Error{
Code: ErrorCodeExecuteMigrations,
err: err,
}
Expand All @@ -582,6 +614,7 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l
applied[history[i].Version] = true
}

var migrationsOutput MigrationsOutput = make(MigrationsOutput)
var count int
for _, m := range migrations {
if limit == 0 {
Expand All @@ -593,7 +626,7 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l
}

if err := c.setSchemaMigrationVersion(ctx, m.Version, true, tableName); err != nil {
return &Error{
return nil, &Error{
Code: ErrorCodeExecuteMigrations,
err: err,
}
Expand All @@ -602,27 +635,37 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l
switch m.kind {
case statementKindDDL:
if err := c.ApplyDDL(ctx, m.Statements); err != nil {
return &Error{
return nil, &Error{
Code: ErrorCodeExecuteMigrations,
err: err,
}
}
case statementKindDML:
if _, err := c.ApplyDML(ctx, m.Statements); err != nil {
return &Error{
rowsAffected, err := c.ApplyDML(ctx, m.Statements)
if err != nil {
return nil, &Error{
Code: ErrorCodeExecuteMigrations,
err: err,
}
}

migrationsOutput[m.FileName] = migrationInfo{
RowsAffected: rowsAffected,
}
case statementKindPartitionedDML:
if _, err := c.ApplyPartitionedDML(ctx, m.Statements); err != nil {
return &Error{
rowsAffected, err := c.ApplyPartitionedDML(ctx, m.Statements)
if err != nil {
return nil, &Error{
Code: ErrorCodeExecuteMigrations,
err: err,
}
}

migrationsOutput[m.FileName] = migrationInfo{
RowsAffected: rowsAffected,
}
default:
return &Error{
return nil, &Error{
Code: ErrorCodeExecuteMigrations,
err: fmt.Errorf("Unknown query type, version: %d", m.Version),
}
Expand All @@ -635,7 +678,7 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l
}

if err := c.setSchemaMigrationVersion(ctx, m.Version, false, tableName); err != nil {
return &Error{
return nil, &Error{
Code: ErrorCodeExecuteMigrations,
err: err,
}
Expand All @@ -651,7 +694,7 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l
fmt.Println("no change")
}

return nil
return migrationsOutput, nil
}

func (c *Client) GetSchemaMigrationVersion(ctx context.Context, tableName string) (uint, bool, error) {
Expand Down
76 changes: 69 additions & 7 deletions pkg/spanner/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,20 +210,29 @@ func TestExecuteMigrations(t *testing.T) {
t.Fatalf("failed to load migrations: %v", err)
}

var migrationsOutput MigrationsOutput
// only apply 000002.sql by specifying limit 1.
if err := client.ExecuteMigrations(ctx, migrations, 1, migrationTable); err != nil {
if migrationsOutput, err = client.ExecuteMigrations(ctx, migrations, 1, migrationTable); err != nil {
t.Fatalf("failed to execute migration: %v", err)
}

if len(migrationsOutput) != 0 {
t.Errorf("want zero length migrationInfo, but got %v", len(migrationsOutput))
}

// ensure that only 000002.sql has been applied.
ensureMigrationColumn(t, ctx, client, "LastName", "STRING(MAX)", "YES")
ensureMigrationVersionRecord(t, ctx, client, 2, false)
ensureMigrationHistoryRecord(t, ctx, client, 2, false)

if err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil {
if migrationsOutput, err = client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil {
t.Fatalf("failed to execute migration: %v", err)
}

if want, got := int64(1), migrationsOutput["000003.sql"].RowsAffected; want != got {
t.Errorf("want %d, but got %d", want, got)
}

// ensure that 000003.sql and 000004.sql have been applied.
ensureMigrationColumn(t, ctx, client, "LastName", "STRING(MAX)", "NO")
ensureMigrationVersionRecord(t, ctx, client, 4, false)
Expand Down Expand Up @@ -499,7 +508,7 @@ func TestHotfixMigration(t *testing.T) {
if err != nil {
t.Fatalf("failed to load migrations: %v", err)
}
if err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil {
if _, err = client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil {
t.Fatalf("failed to execute migration: %v", err)
}
history, err := client.GetMigrationHistory(ctx, migrationTable)
Expand All @@ -517,7 +526,7 @@ func TestHotfixMigration(t *testing.T) {
if err != nil {
t.Fatalf("failed to load migrations: %v", err)
}
if err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil {
if _, err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil {
t.Fatalf("failed to execute migration: %v", err)
}
history, err = client.GetMigrationHistory(ctx, migrationTable)
Expand All @@ -541,7 +550,7 @@ func TestUpgrade(t *testing.T) {
if err != nil {
t.Fatalf("failed to load migrations: %v", err)
}
if err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil {
if _, err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil {
t.Fatalf("failed to execute migration: %v", err)
}
expected, err := client.GetMigrationHistory(ctx, migrationTable)
Expand All @@ -559,7 +568,7 @@ func TestUpgrade(t *testing.T) {
if client.tableExists(ctx, upgradeIndicator) == false {
t.Error("upgrade indicator should exist")
}
if err := client.UpgradeExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil {
if _, err := client.UpgradeExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil {
t.Fatalf("failed to execute migration: %v", err)
}

Expand Down Expand Up @@ -769,11 +778,64 @@ func TestClient_RepairMigration(t *testing.T) {
}
}

func Test_MigrationInfoString(t *testing.T) {
tests := []struct {
testName string
migrationInfo MigrationsOutput
exptectedOutput string
}{
{
testName: "no results",
migrationInfo: MigrationsOutput{},
exptectedOutput: "",
},
{
testName: "unitiated results - panic resiliant",
exptectedOutput: "",
},
{
testName: "one result",
migrationInfo: MigrationsOutput{
"i-deleted-everything.sql": migrationInfo{
RowsAffected: 2000,
},
},
exptectedOutput: "Migration Information:\ni-deleted-everything.sql - rows affected: 2000\n",
},
{
testName: "many results",
migrationInfo: MigrationsOutput{
"0001-i-am-a-cool-update.sql": migrationInfo{
RowsAffected: 20,
},
"0002-not-as-cool-as-me.sql": migrationInfo{
RowsAffected: 25,
},
"0003-i-deleted-everything.sql": migrationInfo{
RowsAffected: 2000,
},
},
exptectedOutput: "Migration Information:\n0001-i-am-a-cool-update.sql - rows affected: 20\n0002-not-as-cool-as-me.sql - rows affected: 25\n0003-i-deleted-everything.sql - rows affected: 2000\n",
},
}

for _, test := range tests {
output := test.migrationInfo.String()
assert.Equal(t, test.exptectedOutput, output)
}
}

func migrateUpDir(t *testing.T, ctx context.Context, client *Client, dir string, toSkip ...uint) error {
t.Helper()
migrations, err := LoadMigrations(dir, toSkip)
if err != nil {
return err
}
return client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable)

_, err = client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable)
if err != nil {
return err
}

return nil
}
4 changes: 4 additions & 0 deletions pkg/spanner/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ type (
// Name is the name of the migration
Name string

// FileName is the name of the source file for the migration
FileName string

// Statements is the migration statements
Statements []string

Expand Down Expand Up @@ -138,6 +141,7 @@ func LoadMigrations(dir string, toSkipSlice []uint) (Migrations, error) {
migrations = append(migrations, &Migration{
Version: uint(version),
Name: matches[2],
FileName: f.Name(),
Statements: statements,
kind: kind,
})
Expand Down

0 comments on commit 1e586c3

Please sign in to comment.