diff --git a/README.md b/README.md index 36fd2871..540ccc22 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,7 @@ The following options are available with all commands. You must use command line - `--url, -u "protocol://host:port/dbname"` - specify the database url directly. _(env: `DATABASE_URL`)_ - `--env, -e "DATABASE_URL"` - specify an environment variable to read the database connection URL from. +- `--env-file ".env"` - specify an alternate environment variables file(s) to load. - `--migrations-dir, -d "./db/migrations"` - where to keep the migration files. _(env: `DBMATE_MIGRATIONS_DIR`)_ - `--migrations-table "schema_migrations"` - database table to record migrations in. _(env: `DBMATE_MIGRATIONS_TABLE`)_ - `--schema-file, -s "./db/schema.sql"` - a path to keep the schema.sql file. _(env: `DBMATE_SCHEMA_FILE`)_ diff --git a/fixtures/loadEnvFiles/.env b/fixtures/loadEnvFiles/.env new file mode 100644 index 00000000..11b4e5ca --- /dev/null +++ b/fixtures/loadEnvFiles/.env @@ -0,0 +1 @@ +TEST_DOTENV=default diff --git a/fixtures/loadEnvFiles/.gitignore b/fixtures/loadEnvFiles/.gitignore new file mode 100644 index 00000000..8e0776e8 --- /dev/null +++ b/fixtures/loadEnvFiles/.gitignore @@ -0,0 +1 @@ +!.env diff --git a/fixtures/loadEnvFiles/first.txt b/fixtures/loadEnvFiles/first.txt new file mode 100644 index 00000000..71c5e94e --- /dev/null +++ b/fixtures/loadEnvFiles/first.txt @@ -0,0 +1 @@ +FIRST=one diff --git a/fixtures/loadEnvFiles/invalid.txt b/fixtures/loadEnvFiles/invalid.txt new file mode 100644 index 00000000..7524c931 --- /dev/null +++ b/fixtures/loadEnvFiles/invalid.txt @@ -0,0 +1 @@ +INVALID ENV FILE diff --git a/fixtures/loadEnvFiles/second.txt b/fixtures/loadEnvFiles/second.txt new file mode 100644 index 00000000..0e1c837b --- /dev/null +++ b/fixtures/loadEnvFiles/second.txt @@ -0,0 +1 @@ +SECOND=two diff --git a/main.go b/main.go index 83b85bf8..160123b3 100644 --- a/main.go +++ b/main.go @@ -1,8 +1,8 @@ package main import ( + "errors" "fmt" - "log" "net/url" "os" "regexp" @@ -17,10 +17,14 @@ import ( ) func main() { - loadDotEnv() + err := loadEnvFiles(os.Args[1:]) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(3) + } app := NewApp() - err := app.Run(os.Args) + err = app.Run(os.Args) if err != nil { errText := redactLogString(fmt.Sprintf("Error: %s\n", err)) @@ -37,6 +41,7 @@ func NewApp() *cli.App { app.Version = dbmate.Version defaultDB := dbmate.New(nil) + app.Flags = []cli.Flag{ &cli.StringFlag{ Name: "url", @@ -49,6 +54,11 @@ func NewApp() *cli.App { Value: "DATABASE_URL", Usage: "specify an environment variable containing the database URL", }, + &cli.StringSliceFlag{ + Name: "env-file", + Value: cli.NewStringSlice(".env"), + Usage: "specify a file to load environment variables from", + }, &cli.StringSliceFlag{ Name: "migrations-dir", Aliases: []string{"d"}, @@ -211,15 +221,46 @@ func NewApp() *cli.App { return app } -// load environment variables from .env file -func loadDotEnv() { - if _, err := os.Stat(".env"); err != nil { - return +// load environment variables from file(s) +func loadEnvFiles(args []string) error { + var envFiles []string + + for i := 0; i < len(args); i++ { + if args[i] == "--env-file" { + if i+1 >= len(args) { + // returning nil here, even though it's an error + // because we want the caller to proceed anyway, + // and produce the actual arg parsing error response + return nil + } + + envFiles = append(envFiles, args[i+1]) + i++ + } + } + + if len(envFiles) == 0 { + envFiles = []string{".env"} } - if err := godotenv.Load(); err != nil { - log.Fatalf("Error loading .env file: %s", err.Error()) + // try to load all files in sequential order, + // ignoring any that do not exist + for _, file := range envFiles { + err := godotenv.Load([]string{file}...) + if err == nil { + continue + } + + var perr *os.PathError + if errors.As(err, &perr) && errors.Is(perr, os.ErrNotExist) { + // Ignoring file not found error + continue + } + + return fmt.Errorf("loading env file(s) %v: %v", envFiles, err) } + + return nil } // action wraps a cli.ActionFunc with dbmate initialization logic diff --git a/main_test.go b/main_test.go index df1cab52..5d184978 100644 --- a/main_test.go +++ b/main_test.go @@ -3,6 +3,7 @@ package main import ( "flag" "os" + "strings" "testing" "github.com/stretchr/testify/require" @@ -58,3 +59,122 @@ func TestRedactLogString(t *testing.T) { require.Equal(t, ex.expected, redactLogString(ex.in)) } } + +func TestLoadEnvFiles(t *testing.T) { + setup := func(t *testing.T) { + cwd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + + env := os.Environ() + os.Clearenv() + + err = os.Chdir("fixtures/loadEnvFiles") + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + err := os.Chdir(cwd) + if err != nil { + t.Fatal(err) + } + + os.Clearenv() + + for _, e := range env { + pair := strings.SplitN(e, "=", 2) + os.Setenv(pair[0], pair[1]) + } + }) + } + + t.Run("default file is .env", func(t *testing.T) { + setup(t) + + err := loadEnvFiles([]string{}) + require.NoError(t, err) + + require.Equal(t, 1, len(os.Environ())) + require.Equal(t, "default", os.Getenv("TEST_DOTENV")) + }) + + t.Run("valid file", func(t *testing.T) { + setup(t) + + err := loadEnvFiles([]string{"--env-file", "first.txt"}) + require.NoError(t, err) + require.Equal(t, 1, len(os.Environ())) + require.Equal(t, "one", os.Getenv("FIRST")) + }) + + t.Run("two valid files", func(t *testing.T) { + setup(t) + + err := loadEnvFiles([]string{"--env-file", "first.txt", "--env-file", "second.txt"}) + require.NoError(t, err) + require.Equal(t, 2, len(os.Environ())) + require.Equal(t, "one", os.Getenv("FIRST")) + require.Equal(t, "two", os.Getenv("SECOND")) + }) + + t.Run("nonexistent file", func(t *testing.T) { + setup(t) + + err := loadEnvFiles([]string{"--env-file", "nonexistent.txt"}) + require.NoError(t, err) + require.Equal(t, 0, len(os.Environ())) + }) + + t.Run("no overload", func(t *testing.T) { + setup(t) + + // we do not load values over existing values + os.Setenv("FIRST", "not one") + + err := loadEnvFiles([]string{"--env-file", "first.txt"}) + require.NoError(t, err) + require.Equal(t, 1, len(os.Environ())) + require.Equal(t, "not one", os.Getenv("FIRST")) + }) + + t.Run("invalid file", func(t *testing.T) { + setup(t) + + err := loadEnvFiles([]string{"--env-file", "invalid.txt"}) + require.Error(t, err) + require.Contains(t, err.Error(), "unexpected character \"\\n\" in variable name near \"INVALID ENV FILE\\n\"") + require.Equal(t, 0, len(os.Environ())) + }) + + t.Run("invalid file followed by a valid file", func(t *testing.T) { + setup(t) + + err := loadEnvFiles([]string{"--env-file", "invalid.txt", "--env-file", "first.txt"}) + require.Error(t, err) + require.Contains(t, err.Error(), "unexpected character \"\\n\" in variable name near \"INVALID ENV FILE\\n\"") + require.Equal(t, 0, len(os.Environ())) + }) + + t.Run("valid file followed by an invalid file", func(t *testing.T) { + setup(t) + + err := loadEnvFiles([]string{"--env-file", "first.txt", "--env-file", "invalid.txt"}) + require.Error(t, err) + require.Contains(t, err.Error(), "unexpected character \"\\n\" in variable name near \"INVALID ENV FILE\\n\"") + require.Equal(t, 1, len(os.Environ())) + require.Equal(t, "one", os.Getenv("FIRST")) + }) + + t.Run("valid file followed by an invalid file followed by a valid file", func(t *testing.T) { + setup(t) + + err := loadEnvFiles([]string{"--env-file", "first.txt", "--env-file", "invalid.txt", "--env-file", "second.txt"}) + require.Error(t, err) + require.Contains(t, err.Error(), "unexpected character \"\\n\" in variable name near \"INVALID ENV FILE\\n\"") + // files after an invalid file should not get loaded + require.Equal(t, 1, len(os.Environ())) + require.Equal(t, "one", os.Getenv("FIRST")) + }) +}