Skip to content

Commit

Permalink
Clean up test helper functions (#544)
Browse files Browse the repository at this point in the history
Reduce duplication in tests and consistent messages for skipped
integration tests.
  • Loading branch information
amacneil committed Apr 21, 2024
1 parent 92fbde5 commit 88486cb
Show file tree
Hide file tree
Showing 14 changed files with 139 additions and 151 deletions.
53 changes: 25 additions & 28 deletions pkg/dbmate/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/amacneil/dbmate/v2/pkg/dbmate"
"github.com/amacneil/dbmate/v2/pkg/dbtest"
"github.com/amacneil/dbmate/v2/pkg/dbutil"
_ "github.com/amacneil/dbmate/v2/pkg/driver/mysql"
_ "github.com/amacneil/dbmate/v2/pkg/driver/postgres"
Expand All @@ -21,12 +22,12 @@ import (

var rootDir string

func sqliteTestURL() *url.URL {
return dbutil.MustParseURL("sqlite:dbmate_test.sqlite3")
func sqliteTestURL(t *testing.T) *url.URL {
return dbtest.MustParseURL(t, "sqlite:dbmate_test.sqlite3")
}

func sqliteBrokenTestURL() *url.URL {
return dbutil.MustParseURL("sqlite:/doesnotexist/dbmate_test.sqlite3")
func sqliteBrokenTestURL(t *testing.T) *url.URL {
return dbtest.MustParseURL(t, "sqlite:/doesnotexist/dbmate_test.sqlite3")
}

func newTestDB(t *testing.T, u *url.URL) *dbmate.DB {
Expand All @@ -48,7 +49,7 @@ func newTestDB(t *testing.T, u *url.URL) *dbmate.DB {
}

func TestNew(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("foo:test"))
db := dbmate.New(dbtest.MustParseURL(t, "foo:test"))
require.True(t, db.AutoDumpSchema)
require.Equal(t, "foo:test", db.DatabaseURL.String())
require.Equal(t, []string{"./db/migrations"}, db.MigrationsDir)
Expand All @@ -68,22 +69,22 @@ func TestGetDriver(t *testing.T) {
})

t.Run("missing schema", func(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("//hi"))
db := dbmate.New(dbtest.MustParseURL(t, "//hi"))
drv, err := db.Driver()
require.Nil(t, drv)
require.EqualError(t, err, "invalid url, have you set your --url flag or DATABASE_URL environment variable?")
})

t.Run("invalid driver", func(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("foo://bar"))
db := dbmate.New(dbtest.MustParseURL(t, "foo://bar"))
drv, err := db.Driver()
require.EqualError(t, err, "unsupported driver: foo")
require.Nil(t, drv)
})
}

func TestWait(t *testing.T) {
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))

// speed up retry loop for testing
db.WaitInterval = time.Millisecond
Expand All @@ -95,7 +96,7 @@ func TestWait(t *testing.T) {
})

t.Run("invalid connection", func(t *testing.T) {
db.DatabaseURL = sqliteBrokenTestURL()
db.DatabaseURL = sqliteBrokenTestURL(t)

err := db.Wait()
require.Error(t, err)
Expand All @@ -104,7 +105,7 @@ func TestWait(t *testing.T) {
}

func TestDumpSchema(t *testing.T) {
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))

// create custom schema file directory
dir, err := os.MkdirTemp("", "dbmate")
Expand Down Expand Up @@ -137,7 +138,7 @@ func TestDumpSchema(t *testing.T) {
}

func TestAutoDumpSchema(t *testing.T) {
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
db.AutoDumpSchema = true

// create custom schema file directory
Expand Down Expand Up @@ -180,7 +181,7 @@ func TestAutoDumpSchema(t *testing.T) {
}

func TestLoadSchema(t *testing.T) {
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
drv, err := db.Driver()
require.NoError(t, err)

Expand Down Expand Up @@ -245,7 +246,7 @@ func TestLoadSchema(t *testing.T) {

func checkWaitCalled(t *testing.T, db *dbmate.DB, command func() error) {
oldDatabaseURL := db.DatabaseURL
db.DatabaseURL = sqliteBrokenTestURL()
db.DatabaseURL = sqliteBrokenTestURL(t)

err := command()
require.Error(t, err)
Expand All @@ -255,7 +256,7 @@ func checkWaitCalled(t *testing.T, db *dbmate.DB, command func() error) {
}

func testWaitBefore(t *testing.T, verbose bool) {
u := sqliteTestURL()
u := sqliteTestURL(t)
db := newTestDB(t, u)
db.Verbose = verbose
db.WaitBefore = true
Expand Down Expand Up @@ -329,20 +330,16 @@ Rows affected: 0`)

func testEachURL(t *testing.T, fn func(*testing.T, *url.URL)) {
t.Run("sqlite", func(t *testing.T) {
fn(t, sqliteTestURL())
fn(t, sqliteTestURL(t))
})

optionalTestURLs := []string{"MYSQL_TEST_URL", "POSTGRES_TEST_URL"}
for _, varname := range optionalTestURLs {
// split on underscore and take first part
testname := strings.ToLower(strings.Split(varname, "_")[0])
t.Run(testname, func(t *testing.T) {
val := os.Getenv(varname)
if val == "" {
t.Skipf("no %s url provided", varname)
} else {
fn(t, dbutil.MustParseURL(val))
}
u := dbtest.GetenvURLOrSkip(t, varname)
fn(t, u)
})
}
}
Expand Down Expand Up @@ -543,7 +540,7 @@ func TestFindMigrations(t *testing.T) {

func TestFindMigrationsAbsolute(t *testing.T) {
t.Run("relative path", func(t *testing.T) {
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
db.MigrationsDir = []string{"db/migrations"}

migrations, err := db.FindMigrations()
Expand All @@ -562,7 +559,7 @@ func TestFindMigrationsAbsolute(t *testing.T) {
require.NoError(t, err)
defer file.Close()

db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
db.MigrationsDir = []string{dir}
require.Nil(t, db.FS)

Expand Down Expand Up @@ -594,7 +591,7 @@ drop table users;
"db/not_migrations/20151129054053_test_migration.sql": {},
}

db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
db.FS = mapFS

// drop and recreate database
Expand Down Expand Up @@ -641,7 +638,7 @@ func TestFindMigrationsFSMultipleDirs(t *testing.T) {
"db/migrations_c/006_test_migration_c.sql": {},
}

db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
db.FS = mapFS
db.MigrationsDir = []string{"./db/migrations_a", "./db/migrations_b", "./db/migrations_c"}

Expand All @@ -667,7 +664,7 @@ func TestMigrateUnrestrictedOrder(t *testing.T) {
emptyMigration := []byte("-- migrate:up\n-- migrate:down")

// initialize database
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))

err := db.Drop()
require.NoError(t, err)
Expand Down Expand Up @@ -698,7 +695,7 @@ func TestMigrateStrictOrder(t *testing.T) {
emptyMigration := []byte("-- migrate:up\n-- migrate:down")

// initialize database
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
db.Strict = true

err := db.Drop()
Expand Down Expand Up @@ -738,7 +735,7 @@ func TestMigrateStrictOrder(t *testing.T) {
}

func TestMigrateQueryErrorMessage(t *testing.T) {
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))

err := db.Drop()
require.NoError(t, err)
Expand Down
36 changes: 36 additions & 0 deletions pkg/dbtest/dbtest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Helper package that should only be used in test files
package dbtest

import (
"net/url"
"os"
"testing"

"github.com/stretchr/testify/require"
)

// MustParseURL parses a URL from string, and fails the test if the URL is invalid.
func MustParseURL(t *testing.T, s string) *url.URL {
require.NotEmpty(t, s)

u, err := url.Parse(s)
require.NoError(t, err)

return u
}

// GetenvOrSkip gets an environment variable, and skips the test if it is empty.
func GetenvOrSkip(t *testing.T, key string) string {
value := os.Getenv(key)
if value == "" {
t.Skipf("no %s provided", key)
}

return value
}

// GetenvURLOrSkip gets an environment variable, parses it as a URL,
// fails the test if the URL is invalid, and skips the test if empty.
func GetenvURLOrSkip(t *testing.T, key string) *url.URL {
return MustParseURL(t, GetenvOrSkip(t, key))
}
15 changes: 0 additions & 15 deletions pkg/dbutil/dbutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,21 +137,6 @@ func QueryValue(db Transaction, query string, args ...interface{}) (string, erro
return result.String, nil
}

// MustParseURL parses a URL from string, and panics if it fails.
// It is used during testing and in cases where we are parsing a generated URL.
func MustParseURL(s string) *url.URL {
if s == "" {
panic("missing url")
}

u, err := url.Parse(s)
if err != nil {
panic(err)
}

return u
}

// MustUnescapePath unescapes a URL path, and panics if it fails.
// It is used during in cases where we are parsing a generated path.
func MustUnescapePath(s string) string {
Expand Down
5 changes: 3 additions & 2 deletions pkg/dbutil/dbutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"testing"

"github.com/amacneil/dbmate/v2/pkg/dbtest"
"github.com/amacneil/dbmate/v2/pkg/dbutil"

_ "github.com/mattn/go-sqlite3" // database/sql driver
Expand All @@ -12,13 +13,13 @@ import (

func TestDatabaseName(t *testing.T) {
t.Run("valid", func(t *testing.T) {
u := dbutil.MustParseURL("foo://host/dbname?query")
u := dbtest.MustParseURL(t, "foo://host/dbname?query")
name := dbutil.DatabaseName(u)
require.Equal(t, "dbname", name)
})

t.Run("empty", func(t *testing.T) {
u := dbutil.MustParseURL("foo://host")
u := dbtest.MustParseURL(t, "foo://host")
name := dbutil.DatabaseName(u)
require.Equal(t, "", name)
})
Expand Down
17 changes: 4 additions & 13 deletions pkg/driver/bigquery/bigquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,25 @@ import (
"database/sql"
"fmt"
"net/url"
"os"
"testing"

"github.com/stretchr/testify/require"

"github.com/amacneil/dbmate/v2/pkg/dbmate"
"github.com/amacneil/dbmate/v2/pkg/dbtest"
"github.com/amacneil/dbmate/v2/pkg/dbutil"
)

func testBigQueryDriver(t *testing.T) *Driver {
url := os.Getenv("BIGQUERY_TEST_URL")
if url == "" {
t.Skip("no BIGQUERY_TEST_URL provided")
}

u := dbutil.MustParseURL(url)
u := dbtest.GetenvURLOrSkip(t, "BIGQUERY_TEST_URL")
drv, err := dbmate.New(u).Driver()
require.NoError(t, err)

return drv.(*Driver)
}

func testGoogleBigQueryDriver(t *testing.T) *Driver {
testURL := os.Getenv("GOOGLE_BIGQUERY_TEST_URL")
if testURL == "" {
t.Skip("no GOOGLE_BIGQUERY_TEST_URL provided")
}
u := dbutil.MustParseURL(testURL)
u := dbtest.GetenvURLOrSkip(t, "GOOGLE_BIGQUERY_TEST_URL")

endpoint := u.Query().Get("endpoint")
if endpoint != "" {
Expand Down Expand Up @@ -86,7 +77,7 @@ func prepTestGoogleBigQueryDB(t *testing.T) *sql.DB {
}

func TestGetDriver(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("bigquery://"))
db := dbmate.New(dbtest.MustParseURL(t, "bigquery://"))
drvInterface, err := db.Driver()
require.NoError(t, err)

Expand Down
9 changes: 7 additions & 2 deletions pkg/driver/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewDriver(config dbmate.DriverConfig) dbmate.Driver {

func connectionString(initialURL *url.URL) string {
// clone url
u := dbutil.MustParseURL(initialURL.String())
u, _ := url.Parse(initialURL.String())

host := u.Host
if u.Port() == "" {
Expand Down Expand Up @@ -109,7 +109,12 @@ func (drv *Driver) onClusterClause() string {
}

func (drv *Driver) databaseName() string {
name := strings.TrimLeft(dbutil.MustParseURL(connectionString(drv.databaseURL)).Path, "/")
u, err := url.Parse(connectionString(drv.databaseURL))
if err != nil {
panic(err)
}

name := strings.TrimLeft(u.Path, "/")
if name == "" {
name = "default"
}
Expand Down

0 comments on commit 88486cb

Please sign in to comment.