Skip to content

Commit

Permalink
Add flag to specify jobs on postgres db load
Browse files Browse the repository at this point in the history
  • Loading branch information
ipmb committed Mar 10, 2022
1 parent d723837 commit 3d9fb8d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
15 changes: 10 additions & 5 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,11 +450,16 @@ func (a *App) StartTask(taskFamily *string, command []string, taskOverride *ecs.
startedBy := fmt.Sprintf("apppack-cli/shell/%s", *email)
runTaskArgs.TaskDefinition = taskDefn.TaskDefinition.TaskDefinitionArn
runTaskArgs.StartedBy = &startedBy
taskOverride.ContainerOverrides = []*ecs.ContainerOverride{
{
Name: taskDefn.TaskDefinition.ContainerDefinitions[0].Name,
Command: cmd,
},
if len(taskOverride.ContainerOverrides) == 1 {
taskOverride.ContainerOverrides[0].Name = taskDefn.TaskDefinition.ContainerDefinitions[0].Name
taskOverride.ContainerOverrides[0].Command = cmd
} else {
taskOverride.ContainerOverrides = []*ecs.ContainerOverride{
{
Name: taskDefn.TaskDefinition.ContainerDefinitions[0].Name,
Command: cmd,
},
}
}
runTaskArgs.Overrides = taskOverride
ecsTaskOutput, err := ecsSvc.RunTask(&runTaskArgs)
Expand Down
39 changes: 37 additions & 2 deletions cmd/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"

Expand All @@ -32,6 +33,7 @@ import (
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/logrusorgru/aurora"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)

var dbOutputFile string
Expand Down Expand Up @@ -162,6 +164,18 @@ func taskLogs(sess *session.Session, task *ecs.Task) error {
return nil
}

var postgresLoadJobs int

func flagIsSet(flags *pflag.FlagSet, name string) bool {
found := false
flags.Visit(func(flag *pflag.Flag) {
if flag.Name == name {
found = true
}
})
return found
}

// dbLoadCmd represents the db load command
var dbLoadCmd = &cobra.Command{
Use: "load <dumpfile>",
Expand All @@ -178,6 +192,15 @@ WARNING: This is a destructive action which will delete the contents of your rem
// db dump load can be really slow, let people open longer sessions to wait for it to finish
app, err := app.Init(AppName, UseAWSCredentials, MaxSessionDurationSeconds)
checkErr(err)
checkErr(app.LoadSettings())
isPostgres := strings.Contains(app.Settings.DBUtils.Engine, "postgres")
// exit if we're not using postgres and --jobs is set
if !isPostgres && flagIsSet(cmd.Flags(), "jobs") {
checkErr(fmt.Errorf("the --jobs/-j flag is only supported for Postgres databases"))
}
if postgresLoadJobs < 1 {
checkErr(fmt.Errorf("the --jobs/-j flag must be set to a positive integer"))
}
family, err := app.DBDumpLoadFamily()
checkErr(err)
ui.Spinner.Stop()
Expand All @@ -199,10 +222,21 @@ WARNING: This is a destructive action which will delete the contents of your rem
})
checkErr(err)
}
taskOverride := &ecs.TaskOverride{}
if isPostgres {
taskOverride.ContainerOverrides = []*ecs.ContainerOverride{
{
Name: aws.String("app"),
Environment: []*ecs.KeyValuePair{
{Name: aws.String("PG_RESTORE_JOBS"), Value: aws.String(strconv.Itoa(postgresLoadJobs))},
},
},
}
}
task, err := app.StartTask(
family,
[]string{"load-from-s3.sh", remoteFile},
&ecs.TaskOverride{},
taskOverride,
true,
)
ui.Spinner.Stop()
Expand All @@ -214,7 +248,7 @@ WARNING: This is a destructive action which will delete the contents of your rem
checkErr(err)
ui.Spinner.Stop()
// pg_restore can have inconsequential errors... don't assume failure, but notify user
if *exitCode != 0 && strings.Contains(app.Settings.DBUtils.Engine, "postgres") {
if *exitCode != 0 && isPostgres {
taskLogs(app.Session, task)
printWarning("check pg_restore output")
} else if *exitCode != 0 {
Expand All @@ -236,4 +270,5 @@ func init() {
dbCmd.AddCommand(dbDumpCmd)
dbDumpCmd.Flags().StringVarP(&dbOutputFile, "output", "o", "", "path to output file -- default will be <app-name> with the appropriate extension for the database")
dbCmd.AddCommand(dbLoadCmd)
dbLoadCmd.Flags().IntVarP(&postgresLoadJobs, "jobs", "j", 2, "number of jobs to use for the load (passed through as --jobs to pg_restore -- Postgres only)")
}

0 comments on commit 3d9fb8d

Please sign in to comment.