Skip to content

Commit

Permalink
feat: Update migrations postgres URI parsing and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Fernando Barillas authored and acaloiaro committed Mar 5, 2024
1 parent 0c32209 commit 672e101
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 24 deletions.
80 changes: 56 additions & 24 deletions backends/postgres/postgres_backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"net/url"
"os"
"strings"
"sync"
"time"

Expand All @@ -34,6 +33,9 @@ import (
var migrationsFS embed.FS

const (
queryParamSSLMode = "sslmode"
queryParamMigrationsTable = "x-migrations-table"

JobQuery = `SELECT id,fingerprint,queue,status,deadline,payload,retries,max_retries,run_after,ran_at,created_at,error
FROM neoq_jobs
WHERE id = $1
Expand Down Expand Up @@ -324,6 +326,7 @@ func txFromContext(ctx context.Context) (t pgx.Tx, err error) {
func (p *PgBackend) initializeDB() (err error) {
migrations, err := iofs.New(migrationsFS, "migrations")
if err != nil {
err = fmt.Errorf("unable to run migrations, error during iofs new: %w", err)
p.logger.Error("unable to run migrations", slog.Any("error", err))
return
}
Expand All @@ -332,36 +335,16 @@ func (p *PgBackend) initializeDB() (err error) {
// it with pgx-specific config params like `max_conn_count`. However, `go-migrate` uses `pq` under the hood, and
// these `pgx` config params cause `pq` to throw an "unknown config parameter" error when they're encountered.
// So we must first sanitize connection strings for pq
var pgxCfg *pgx.ConnConfig
pgxCfg, err = pgx.ParseConfig(p.config.ConnectionString)
pqConnectionString, err := GetPQConnectionString(p.config.ConnectionString)
if err != nil {
err = fmt.Errorf("unable to run migrations, error parsing connection string: %w", err)
p.logger.Error("unable to run migrations", slog.Any("error", err))
return
}

// nil TLSConfig means "sslmode=disable" was set on the connection
sslMode := "verify-ca"
if pgxCfg.TLSConfig == nil {
sslMode = "disable"
} else if pgxCfg.TLSConfig.InsecureSkipVerify {
sslMode = "require"
}
if dbURL, err := url.Parse(pgxCfg.ConnString()); err == nil &&
strings.HasPrefix(dbURL.Scheme, "postgres") {
val := dbURL.Query()
if v := val.Get("sslmode"); v != "" {
sslMode = v // set sslmode from existing connection string
}
}

pqConnectionString := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=%s&x-migrations-table=neoq_schema_migrations",
pgxCfg.User,
url.QueryEscape(pgxCfg.Password),
pgxCfg.Host,
pgxCfg.Database,
sslMode)
m, err := migrate.NewWithSourceInstance("iofs", migrations, pqConnectionString)
if err != nil {
err = fmt.Errorf("unable to run migrations, could not create new source: %w", err)
p.logger.Error("unable to run migrations", slog.Any("error", err))
return
}
Expand All @@ -370,6 +353,7 @@ func (p *PgBackend) initializeDB() (err error) {

err = m.Up()
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
err = fmt.Errorf("unable to run migrations, could not apply up migration: %w", err)
p.logger.Error("unable to run migrations", slog.Any("error", err))
return
}
Expand Down Expand Up @@ -1030,3 +1014,51 @@ func (p *PgBackend) acquire(ctx context.Context) (conn *pgxpool.Conn, err error)
func withJobContext(ctx context.Context, j *jobs.Job) context.Context {
return context.WithValue(ctx, internal.JobCtxVarKey, j)
}

func GetPQConnectionString(connectionString string) (string, error) {
pgxCfg, err := pgx.ParseConfig(connectionString)
if err != nil {
return "", fmt.Errorf("unable to parse connection string %s: %w", connectionString, err)
}

dbURI, err := url.Parse(pgxCfg.ConnString())
if err != nil {
return "", fmt.Errorf("unable to parse connection string %s: %w", connectionString, err)
}

if dbURI.String() == "" {
return "", fmt.Errorf("connection string cannot be empty")
}

scheme := dbURI.Scheme
if scheme == "" {
// This is probably a pq-style string, return it as-is
return connectionString, nil
}

if scheme != "postgres" && scheme != "postgresql" {
// This isn't a postgresql URI-style string (postgres://hostname/db)
return "", fmt.Errorf("only postgres and postgresql scheme URIs are supported, invalid connection string: %s", connectionString)
}

sslMode := "verify-ca"
if pgxCfg.TLSConfig == nil {
sslMode = "disable"
} else if pgxCfg.TLSConfig.InsecureSkipVerify {
sslMode = "require"
}

// Prefer original sslmode if it was set
originalSSLMode := dbURI.Query().Get(queryParamSSLMode)
if originalSSLMode != "" {
sslMode = originalSSLMode
}

// Clear out original query, use only query params that are pq compatible
query := url.Values{}
query.Set(queryParamSSLMode, sslMode)
query.Set(queryParamMigrationsTable, "neoq_schema_migrations")
dbURI.RawQuery = query.Encode()

return dbURI.String(), nil
}
174 changes: 174 additions & 0 deletions backends/postgres/postgres_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/url"
"os"
"strings"
"sync"
Expand Down Expand Up @@ -825,3 +826,176 @@ func TestFutureJobProcessing(t *testing.T) {
t.Error("job ran before RunAfter")
}
}

func TestGetPQConnectionString(t *testing.T) {
tests := []struct {
name string
input string
want string
wantErr bool
}{
{
name: "standard input",
input: "postgres://username:password@hostname:5432/database",
want: "postgres://username:password@hostname:5432/database?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "standard input with postgresql scheme",
input: "postgresql://username:password@hostname:5432/database",
want: "postgresql://username:password@hostname:5432/database?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "no port number",
input: "postgres://username:password@hostname/database",
want: "postgres://username:password@hostname/database?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "custom port number",
input: "postgres://username:password@hostname:1234/database",
want: "postgres://username:password@hostname:1234/database?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "custom sslmode=disable",
input: "postgres://username:password@hostname:5432/database?sslmode=disable",
want: "postgres://username:password@hostname:5432/database?sslmode=disable&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "custom sslmode=allow",
input: "postgres://username:password@hostname:5432/database?sslmode=allow",
want: "postgres://username:password@hostname:5432/database?sslmode=allow&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "custom sslmode=prefer",
input: "postgres://username:password@hostname:5432/database?sslmode=prefer",
want: "postgres://username:password@hostname:5432/database?sslmode=prefer&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "custom sslmode=require",
input: "postgres://username:password@hostname:5432/database?sslmode=require",
want: "postgres://username:password@hostname:5432/database?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "custom sslmode=verify-ca",
input: "postgres://username:password@hostname:5432/database?sslmode=verify-ca",
want: "postgres://username:password@hostname:5432/database?sslmode=verify-ca&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "custom sslmode=verify-full",
input: "postgres://username:password@hostname:5432/database?sslmode=verify-full",
want: "postgres://username:password@hostname:5432/database?sslmode=verify-full&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "encoded password is preserved",
input: "postgres://username:pass%21%40%23$%25%5E&%2A%28%29%3A%2F%3Fword@hostname:5432/database",
want: fmt.Sprintf(
"postgres://%s@hostname:5432/database?sslmode=require&x-migrations-table=neoq_schema_migrations",
url.UserPassword("username", "pass!@#$%^&*():/?word").String(),
),
wantErr: false,
},
{
name: "multiple hostnames",
input: "postgres://username:password@hostname1,hostname2,hostname3:5432/database",
want: "postgres://username:password@hostname1,hostname2,hostname3:5432/database?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},

// Examples connstrings from https://www.postgresql.org/docs/16/libpq-connect.html
{
name: "valid empty postgresql scheme input",
input: "postgresql://",
want: "postgresql:?sslmode=disable&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "hostname localhost",
input: "postgresql://localhost",
want: "postgresql://localhost?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "hostname localhost with custom port",
input: "postgresql://localhost:5433",
want: "postgresql://localhost:5433?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "non-default database",
input: "postgresql://localhost/mydb",
want: "postgresql://localhost/mydb?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "username",
input: "postgresql://user@localhost",
want: "postgresql://user@localhost?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "username and password",
input: "postgresql://user:secret@localhost",
want: "postgresql://user:secret@localhost?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "custom params are ignored",
input: "postgresql://other@localhost/otherdb?connect_timeout=10&application_name=myapp",
want: "postgresql://other@localhost/otherdb?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "multiple hostnames and ports",
input: "postgresql://host1:123,host2:456/somedb?target_session_attrs=any&application_name=myapp",
want: "postgresql://host1:123,host2:456/somedb?sslmode=require&x-migrations-table=neoq_schema_migrations",
wantErr: false,
},
{
name: "pq-style input is returned as-is",
input: "host=localhost port=5432 dbname=mydb connect_timeout=10",
want: "host=localhost port=5432 dbname=mydb connect_timeout=10",
wantErr: false,
},

// Inputs that cause errors
{
name: "non-postgres scheme returns error",
input: "https://user:password@example.com:443/path?query=true",
want: "",
wantErr: true,
},
{
name: "empty input returns error",
input: "",
want: "",
wantErr: true,
},
{
name: "custom bad sslmode=foo returns error",
input: "postgres://username:password@hostname:1234/database?sslmode=foo",
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := postgres.GetPQConnectionString(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("GetPQConnectionString() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("GetPQConnectionString()\ngot = %v\nwant = %v", got, tt.want)
}
})
}
}

0 comments on commit 672e101

Please sign in to comment.