diff --git a/README.md b/README.md index 125a23f..ee0a8ba 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,7 @@ Flags: --schema-file string Name of schema file (optional. if not set, will use default 'schema.sql' file name) --sequence-interval uint16 Used to generate the next migration id. Rounds up to the next interval. (optional. if not set, will use $WRENCH_SEQUENCE_INTERVAL or default to 1) (default 1) --static-data-tables-file string File containing list of static data tables to track (optional) + --verbose Used to indicate whether to output Migration information during a migration -v, --version version for wrench Use "wrench [command] --help" for more information about a command. diff --git a/cmd/cmd.go b/cmd/cmd.go index 7b4a764..a3ec443 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -40,6 +40,7 @@ const ( flagNameSchemaFile = "schema-file" flagLockIdentifier = "lock-identifier" flagSequenceInterval = "sequence-interval" + flagVerbose = "verbose" flagDDLFile = "ddl" flagDMLFile = "dml" flagPartitioned = "partitioned" diff --git a/cmd/migrate.go b/cmd/migrate.go index f9b4517..f33442a 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -211,17 +211,30 @@ func migrateUp(c *cobra.Command, args []string) error { err: err, } } + + var migrationsOutput spanner.MigrationsOutput switch status { case spanner.ExistingMigrationsUpgradeStarted: - return client.UpgradeExecuteMigrations(ctx, migrations, limit, migrationTableName) + migrationsOutput, err = client.UpgradeExecuteMigrations(ctx, migrations, limit, migrationTableName) + if err != nil { + return err + } case spanner.ExistingMigrationsUpgradeCompleted: - return client.ExecuteMigrations(ctx, migrations, limit, migrationTableName) + migrationsOutput, err = client.ExecuteMigrations(ctx, migrations, limit, migrationTableName) + if err != nil { + return err + } default: return &Error{ cmd: c, err: errors.New("migration in undetermined state"), } } + if verbose { + fmt.Print(migrationsOutput.String()) + } + + return nil } func migrateVersion(c *cobra.Command, args []string) error { diff --git a/cmd/root.go b/cmd/root.go index 6546524..044298e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -46,6 +46,7 @@ var ( staticDataTablesFile string lockIdentifier string sequenceInterval uint16 + verbose bool ) var rootCmd = &cobra.Command{ @@ -84,6 +85,7 @@ func init() { rootCmd.PersistentFlags().StringVar(&staticDataTablesFile, flagStaticDataTablesFile, "", "File containing list of static data tables to track (optional)") rootCmd.PersistentFlags().StringVar(&lockIdentifier, flagLockIdentifier, getLockIdentifier(), "Random identifier used to lock migration operations to a single wrench process. (optional. if not set then it will be generated)") rootCmd.PersistentFlags().Uint16Var(&sequenceInterval, flagSequenceInterval, getSequenceInterval(), "Used to generate the next migration id. Rounds up to the next interval. (optional. if not set, will use $WRENCH_SEQUENCE_INTERVAL or default to 1)") + rootCmd.PersistentFlags().BoolVar(&verbose, flagVerbose, false, "Used to indicate whether to output Migration information during a migration") rootCmd.Version = versioninfo.Version rootCmd.SetVersionTemplate(versionTemplate) diff --git a/pkg/spanner/client.go b/pkg/spanner/client.go index e16754d..f6f74c1 100644 --- a/pkg/spanner/client.go +++ b/pkg/spanner/client.go @@ -446,18 +446,23 @@ func (c *Client) ApplyPartitionedDML(ctx context.Context, statements []string) ( return numAffectedRows, nil } -func (c *Client) UpgradeExecuteMigrations(ctx context.Context, migrations Migrations, limit int, tableName string) error { +func (c *Client) UpgradeExecuteMigrations(ctx context.Context, migrations Migrations, limit int, tableName string) (MigrationsOutput, error) { err := c.backfillMigrations(ctx, migrations, tableName) if err != nil { - return err + return nil, err } - err = c.ExecuteMigrations(ctx, migrations, limit, tableName) + migrationsOutput, err := c.ExecuteMigrations(ctx, migrations, limit, tableName) if err != nil { - return err + return nil, err } - return c.markUpgradeComplete(ctx) + err = c.markUpgradeComplete(ctx) + if err != nil { + return nil, err + } + + return migrationsOutput, nil } func (c *Client) backfillMigrations(ctx context.Context, migrations Migrations, tableName string) error { @@ -549,14 +554,41 @@ func (c *Client) GetMigrationHistory(ctx context.Context, versionTableName strin return history, nil } -func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, limit int, tableName string) error { +type MigrationsOutput map[string]migrationInfo + +type migrationInfo struct { + RowsAffected int64 +} + +func (i MigrationsOutput) String() string { + if len(i) == 0 { + return "" + } + + var filenames []string + for filename := range i { + filenames = append(filenames, filename) + } + + sort.StringSlice(filenames).Sort() + + output := "Migration Information:" + for _, filename := range filenames { + migrationInfo := i[filename] + output = fmt.Sprintf("%s\n%s - rows affected: %d", output, filename, migrationInfo.RowsAffected) + } + + return fmt.Sprintf("%s\n", output) +} + +func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, limit int, tableName string) (MigrationsOutput, error) { sort.Sort(migrations) version, dirty, err := c.GetSchemaMigrationVersion(ctx, tableName) if err != nil { var se *Error if !errors.As(err, &se) || se.Code != ErrorCodeNoMigration { - return &Error{ + return nil, &Error{ Code: ErrorCodeExecuteMigrations, err: err, } @@ -564,15 +596,15 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l } if dirty { - return &Error{ + return nil, &Error{ Code: ErrorCodeMigrationVersionDirty, - err: fmt.Errorf("Database version: %d is dirty, please fix it.", version), + err: fmt.Errorf("database version: %d is dirty, please fix it.", version), } } history, err := c.GetMigrationHistory(ctx, tableName) if err != nil { - return &Error{ + return nil, &Error{ Code: ErrorCodeExecuteMigrations, err: err, } @@ -582,6 +614,7 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l applied[history[i].Version] = true } + var migrationsOutput MigrationsOutput = make(MigrationsOutput) var count int for _, m := range migrations { if limit == 0 { @@ -593,7 +626,7 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l } if err := c.setSchemaMigrationVersion(ctx, m.Version, true, tableName); err != nil { - return &Error{ + return nil, &Error{ Code: ErrorCodeExecuteMigrations, err: err, } @@ -602,27 +635,37 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l switch m.kind { case statementKindDDL: if err := c.ApplyDDL(ctx, m.Statements); err != nil { - return &Error{ + return nil, &Error{ Code: ErrorCodeExecuteMigrations, err: err, } } case statementKindDML: - if _, err := c.ApplyDML(ctx, m.Statements); err != nil { - return &Error{ + rowsAffected, err := c.ApplyDML(ctx, m.Statements) + if err != nil { + return nil, &Error{ Code: ErrorCodeExecuteMigrations, err: err, } } + + migrationsOutput[m.FileName] = migrationInfo{ + RowsAffected: rowsAffected, + } case statementKindPartitionedDML: - if _, err := c.ApplyPartitionedDML(ctx, m.Statements); err != nil { - return &Error{ + rowsAffected, err := c.ApplyPartitionedDML(ctx, m.Statements) + if err != nil { + return nil, &Error{ Code: ErrorCodeExecuteMigrations, err: err, } } + + migrationsOutput[m.FileName] = migrationInfo{ + RowsAffected: rowsAffected, + } default: - return &Error{ + return nil, &Error{ Code: ErrorCodeExecuteMigrations, err: fmt.Errorf("Unknown query type, version: %d", m.Version), } @@ -635,7 +678,7 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l } if err := c.setSchemaMigrationVersion(ctx, m.Version, false, tableName); err != nil { - return &Error{ + return nil, &Error{ Code: ErrorCodeExecuteMigrations, err: err, } @@ -651,7 +694,7 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l fmt.Println("no change") } - return nil + return migrationsOutput, nil } func (c *Client) GetSchemaMigrationVersion(ctx context.Context, tableName string) (uint, bool, error) { diff --git a/pkg/spanner/client_test.go b/pkg/spanner/client_test.go index 2ec4d19..34f0b3d 100644 --- a/pkg/spanner/client_test.go +++ b/pkg/spanner/client_test.go @@ -210,20 +210,29 @@ func TestExecuteMigrations(t *testing.T) { t.Fatalf("failed to load migrations: %v", err) } + var migrationsOutput MigrationsOutput // only apply 000002.sql by specifying limit 1. - if err := client.ExecuteMigrations(ctx, migrations, 1, migrationTable); err != nil { + if migrationsOutput, err = client.ExecuteMigrations(ctx, migrations, 1, migrationTable); err != nil { t.Fatalf("failed to execute migration: %v", err) } + if len(migrationsOutput) != 0 { + t.Errorf("want zero length migrationInfo, but got %v", len(migrationsOutput)) + } + // ensure that only 000002.sql has been applied. ensureMigrationColumn(t, ctx, client, "LastName", "STRING(MAX)", "YES") ensureMigrationVersionRecord(t, ctx, client, 2, false) ensureMigrationHistoryRecord(t, ctx, client, 2, false) - if err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil { + if migrationsOutput, err = client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil { t.Fatalf("failed to execute migration: %v", err) } + if want, got := int64(1), migrationsOutput["000003.sql"].RowsAffected; want != got { + t.Errorf("want %d, but got %d", want, got) + } + // ensure that 000003.sql and 000004.sql have been applied. ensureMigrationColumn(t, ctx, client, "LastName", "STRING(MAX)", "NO") ensureMigrationVersionRecord(t, ctx, client, 4, false) @@ -499,7 +508,7 @@ func TestHotfixMigration(t *testing.T) { if err != nil { t.Fatalf("failed to load migrations: %v", err) } - if err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil { + if _, err = client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil { t.Fatalf("failed to execute migration: %v", err) } history, err := client.GetMigrationHistory(ctx, migrationTable) @@ -517,7 +526,7 @@ func TestHotfixMigration(t *testing.T) { if err != nil { t.Fatalf("failed to load migrations: %v", err) } - if err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil { + if _, err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil { t.Fatalf("failed to execute migration: %v", err) } history, err = client.GetMigrationHistory(ctx, migrationTable) @@ -541,7 +550,7 @@ func TestUpgrade(t *testing.T) { if err != nil { t.Fatalf("failed to load migrations: %v", err) } - if err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil { + if _, err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil { t.Fatalf("failed to execute migration: %v", err) } expected, err := client.GetMigrationHistory(ctx, migrationTable) @@ -559,7 +568,7 @@ func TestUpgrade(t *testing.T) { if client.tableExists(ctx, upgradeIndicator) == false { t.Error("upgrade indicator should exist") } - if err := client.UpgradeExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil { + if _, err := client.UpgradeExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil { t.Fatalf("failed to execute migration: %v", err) } @@ -769,11 +778,64 @@ func TestClient_RepairMigration(t *testing.T) { } } +func Test_MigrationInfoString(t *testing.T) { + tests := []struct { + testName string + migrationInfo MigrationsOutput + exptectedOutput string + }{ + { + testName: "no results", + migrationInfo: MigrationsOutput{}, + exptectedOutput: "", + }, + { + testName: "unitiated results - panic resiliant", + exptectedOutput: "", + }, + { + testName: "one result", + migrationInfo: MigrationsOutput{ + "i-deleted-everything.sql": migrationInfo{ + RowsAffected: 2000, + }, + }, + exptectedOutput: "Migration Information:\ni-deleted-everything.sql - rows affected: 2000\n", + }, + { + testName: "many results", + migrationInfo: MigrationsOutput{ + "0001-i-am-a-cool-update.sql": migrationInfo{ + RowsAffected: 20, + }, + "0002-not-as-cool-as-me.sql": migrationInfo{ + RowsAffected: 25, + }, + "0003-i-deleted-everything.sql": migrationInfo{ + RowsAffected: 2000, + }, + }, + exptectedOutput: "Migration Information:\n0001-i-am-a-cool-update.sql - rows affected: 20\n0002-not-as-cool-as-me.sql - rows affected: 25\n0003-i-deleted-everything.sql - rows affected: 2000\n", + }, + } + + for _, test := range tests { + output := test.migrationInfo.String() + assert.Equal(t, test.exptectedOutput, output) + } +} + func migrateUpDir(t *testing.T, ctx context.Context, client *Client, dir string, toSkip ...uint) error { t.Helper() migrations, err := LoadMigrations(dir, toSkip) if err != nil { return err } - return client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable) + + _, err = client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable) + if err != nil { + return err + } + + return nil } diff --git a/pkg/spanner/migration.go b/pkg/spanner/migration.go index 352273a..d2d0b20 100644 --- a/pkg/spanner/migration.go +++ b/pkg/spanner/migration.go @@ -70,6 +70,9 @@ type ( // Name is the name of the migration Name string + // FileName is the name of the source file for the migration + FileName string + // Statements is the migration statements Statements []string @@ -138,6 +141,7 @@ func LoadMigrations(dir string, toSkipSlice []uint) (Migrations, error) { migrations = append(migrations, &Migration{ Version: uint(version), Name: matches[2], + FileName: f.Name(), Statements: statements, kind: kind, })