Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up test helper functions #544

Merged
merged 1 commit into from
Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 25 additions & 28 deletions pkg/dbmate/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/amacneil/dbmate/v2/pkg/dbmate"
"github.com/amacneil/dbmate/v2/pkg/dbtest"
"github.com/amacneil/dbmate/v2/pkg/dbutil"
_ "github.com/amacneil/dbmate/v2/pkg/driver/mysql"
_ "github.com/amacneil/dbmate/v2/pkg/driver/postgres"
Expand All @@ -21,12 +22,12 @@ import (

var rootDir string

func sqliteTestURL() *url.URL {
return dbutil.MustParseURL("sqlite:dbmate_test.sqlite3")
func sqliteTestURL(t *testing.T) *url.URL {
return dbtest.MustParseURL(t, "sqlite:dbmate_test.sqlite3")
}

func sqliteBrokenTestURL() *url.URL {
return dbutil.MustParseURL("sqlite:/doesnotexist/dbmate_test.sqlite3")
func sqliteBrokenTestURL(t *testing.T) *url.URL {
return dbtest.MustParseURL(t, "sqlite:/doesnotexist/dbmate_test.sqlite3")
}

func newTestDB(t *testing.T, u *url.URL) *dbmate.DB {
Expand All @@ -48,7 +49,7 @@ func newTestDB(t *testing.T, u *url.URL) *dbmate.DB {
}

func TestNew(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("foo:test"))
db := dbmate.New(dbtest.MustParseURL(t, "foo:test"))
require.True(t, db.AutoDumpSchema)
require.Equal(t, "foo:test", db.DatabaseURL.String())
require.Equal(t, []string{"./db/migrations"}, db.MigrationsDir)
Expand All @@ -68,22 +69,22 @@ func TestGetDriver(t *testing.T) {
})

t.Run("missing schema", func(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("//hi"))
db := dbmate.New(dbtest.MustParseURL(t, "//hi"))
drv, err := db.Driver()
require.Nil(t, drv)
require.EqualError(t, err, "invalid url, have you set your --url flag or DATABASE_URL environment variable?")
})

t.Run("invalid driver", func(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("foo://bar"))
db := dbmate.New(dbtest.MustParseURL(t, "foo://bar"))
drv, err := db.Driver()
require.EqualError(t, err, "unsupported driver: foo")
require.Nil(t, drv)
})
}

func TestWait(t *testing.T) {
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))

// speed up retry loop for testing
db.WaitInterval = time.Millisecond
Expand All @@ -95,7 +96,7 @@ func TestWait(t *testing.T) {
})

t.Run("invalid connection", func(t *testing.T) {
db.DatabaseURL = sqliteBrokenTestURL()
db.DatabaseURL = sqliteBrokenTestURL(t)

err := db.Wait()
require.Error(t, err)
Expand All @@ -104,7 +105,7 @@ func TestWait(t *testing.T) {
}

func TestDumpSchema(t *testing.T) {
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))

// create custom schema file directory
dir, err := os.MkdirTemp("", "dbmate")
Expand Down Expand Up @@ -137,7 +138,7 @@ func TestDumpSchema(t *testing.T) {
}

func TestAutoDumpSchema(t *testing.T) {
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
db.AutoDumpSchema = true

// create custom schema file directory
Expand Down Expand Up @@ -180,7 +181,7 @@ func TestAutoDumpSchema(t *testing.T) {
}

func TestLoadSchema(t *testing.T) {
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
drv, err := db.Driver()
require.NoError(t, err)

Expand Down Expand Up @@ -245,7 +246,7 @@ func TestLoadSchema(t *testing.T) {

func checkWaitCalled(t *testing.T, db *dbmate.DB, command func() error) {
oldDatabaseURL := db.DatabaseURL
db.DatabaseURL = sqliteBrokenTestURL()
db.DatabaseURL = sqliteBrokenTestURL(t)

err := command()
require.Error(t, err)
Expand All @@ -255,7 +256,7 @@ func checkWaitCalled(t *testing.T, db *dbmate.DB, command func() error) {
}

func testWaitBefore(t *testing.T, verbose bool) {
u := sqliteTestURL()
u := sqliteTestURL(t)
db := newTestDB(t, u)
db.Verbose = verbose
db.WaitBefore = true
Expand Down Expand Up @@ -329,20 +330,16 @@ Rows affected: 0`)

func testEachURL(t *testing.T, fn func(*testing.T, *url.URL)) {
t.Run("sqlite", func(t *testing.T) {
fn(t, sqliteTestURL())
fn(t, sqliteTestURL(t))
})

optionalTestURLs := []string{"MYSQL_TEST_URL", "POSTGRES_TEST_URL"}
for _, varname := range optionalTestURLs {
// split on underscore and take first part
testname := strings.ToLower(strings.Split(varname, "_")[0])
t.Run(testname, func(t *testing.T) {
val := os.Getenv(varname)
if val == "" {
t.Skipf("no %s url provided", varname)
} else {
fn(t, dbutil.MustParseURL(val))
}
u := dbtest.GetenvURLOrSkip(t, varname)
fn(t, u)
})
}
}
Expand Down Expand Up @@ -543,7 +540,7 @@ func TestFindMigrations(t *testing.T) {

func TestFindMigrationsAbsolute(t *testing.T) {
t.Run("relative path", func(t *testing.T) {
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
db.MigrationsDir = []string{"db/migrations"}

migrations, err := db.FindMigrations()
Expand All @@ -562,7 +559,7 @@ func TestFindMigrationsAbsolute(t *testing.T) {
require.NoError(t, err)
defer file.Close()

db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
db.MigrationsDir = []string{dir}
require.Nil(t, db.FS)

Expand Down Expand Up @@ -594,7 +591,7 @@ drop table users;
"db/not_migrations/20151129054053_test_migration.sql": {},
}

db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
db.FS = mapFS

// drop and recreate database
Expand Down Expand Up @@ -641,7 +638,7 @@ func TestFindMigrationsFSMultipleDirs(t *testing.T) {
"db/migrations_c/006_test_migration_c.sql": {},
}

db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
db.FS = mapFS
db.MigrationsDir = []string{"./db/migrations_a", "./db/migrations_b", "./db/migrations_c"}

Expand All @@ -667,7 +664,7 @@ func TestMigrateUnrestrictedOrder(t *testing.T) {
emptyMigration := []byte("-- migrate:up\n-- migrate:down")

// initialize database
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))

err := db.Drop()
require.NoError(t, err)
Expand Down Expand Up @@ -698,7 +695,7 @@ func TestMigrateStrictOrder(t *testing.T) {
emptyMigration := []byte("-- migrate:up\n-- migrate:down")

// initialize database
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))
db.Strict = true

err := db.Drop()
Expand Down Expand Up @@ -738,7 +735,7 @@ func TestMigrateStrictOrder(t *testing.T) {
}

func TestMigrateQueryErrorMessage(t *testing.T) {
db := newTestDB(t, sqliteTestURL())
db := newTestDB(t, sqliteTestURL(t))

err := db.Drop()
require.NoError(t, err)
Expand Down
36 changes: 36 additions & 0 deletions pkg/dbtest/dbtest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Helper package that should only be used in test files
package dbtest

import (
"net/url"
"os"
"testing"

"github.com/stretchr/testify/require"
)

// MustParseURL parses a URL from string, and fails the test if the URL is invalid.
func MustParseURL(t *testing.T, s string) *url.URL {
require.NotEmpty(t, s)

u, err := url.Parse(s)
require.NoError(t, err)

return u
}

// GetenvOrSkip gets an environment variable, and skips the test if it is empty.
func GetenvOrSkip(t *testing.T, key string) string {
value := os.Getenv(key)
if value == "" {
t.Skipf("no %s provided", key)
}

return value
}

// GetenvURLOrSkip gets an environment variable, parses it as a URL,
// fails the test if the URL is invalid, and skips the test if empty.
func GetenvURLOrSkip(t *testing.T, key string) *url.URL {
return MustParseURL(t, GetenvOrSkip(t, key))
}
15 changes: 0 additions & 15 deletions pkg/dbutil/dbutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,21 +137,6 @@ func QueryValue(db Transaction, query string, args ...interface{}) (string, erro
return result.String, nil
}

// MustParseURL parses a URL from string, and panics if it fails.
// It is used during testing and in cases where we are parsing a generated URL.
func MustParseURL(s string) *url.URL {
if s == "" {
panic("missing url")
}

u, err := url.Parse(s)
if err != nil {
panic(err)
}

return u
}

// MustUnescapePath unescapes a URL path, and panics if it fails.
// It is used during in cases where we are parsing a generated path.
func MustUnescapePath(s string) string {
Expand Down
5 changes: 3 additions & 2 deletions pkg/dbutil/dbutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"testing"

"github.com/amacneil/dbmate/v2/pkg/dbtest"
"github.com/amacneil/dbmate/v2/pkg/dbutil"

_ "github.com/mattn/go-sqlite3" // database/sql driver
Expand All @@ -12,13 +13,13 @@ import (

func TestDatabaseName(t *testing.T) {
t.Run("valid", func(t *testing.T) {
u := dbutil.MustParseURL("foo://host/dbname?query")
u := dbtest.MustParseURL(t, "foo://host/dbname?query")
name := dbutil.DatabaseName(u)
require.Equal(t, "dbname", name)
})

t.Run("empty", func(t *testing.T) {
u := dbutil.MustParseURL("foo://host")
u := dbtest.MustParseURL(t, "foo://host")
name := dbutil.DatabaseName(u)
require.Equal(t, "", name)
})
Expand Down
17 changes: 4 additions & 13 deletions pkg/driver/bigquery/bigquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,25 @@ import (
"database/sql"
"fmt"
"net/url"
"os"
"testing"

"github.com/stretchr/testify/require"

"github.com/amacneil/dbmate/v2/pkg/dbmate"
"github.com/amacneil/dbmate/v2/pkg/dbtest"
"github.com/amacneil/dbmate/v2/pkg/dbutil"
)

func testBigQueryDriver(t *testing.T) *Driver {
url := os.Getenv("BIGQUERY_TEST_URL")
if url == "" {
t.Skip("no BIGQUERY_TEST_URL provided")
}

u := dbutil.MustParseURL(url)
u := dbtest.GetenvURLOrSkip(t, "BIGQUERY_TEST_URL")
drv, err := dbmate.New(u).Driver()
require.NoError(t, err)

return drv.(*Driver)
}

func testGoogleBigQueryDriver(t *testing.T) *Driver {
testURL := os.Getenv("GOOGLE_BIGQUERY_TEST_URL")
if testURL == "" {
t.Skip("no GOOGLE_BIGQUERY_TEST_URL provided")
}
u := dbutil.MustParseURL(testURL)
u := dbtest.GetenvURLOrSkip(t, "GOOGLE_BIGQUERY_TEST_URL")

endpoint := u.Query().Get("endpoint")
if endpoint != "" {
Expand Down Expand Up @@ -86,7 +77,7 @@ func prepTestGoogleBigQueryDB(t *testing.T) *sql.DB {
}

func TestGetDriver(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("bigquery://"))
db := dbmate.New(dbtest.MustParseURL(t, "bigquery://"))
drvInterface, err := db.Driver()
require.NoError(t, err)

Expand Down
9 changes: 7 additions & 2 deletions pkg/driver/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewDriver(config dbmate.DriverConfig) dbmate.Driver {

func connectionString(initialURL *url.URL) string {
// clone url
u := dbutil.MustParseURL(initialURL.String())
u, _ := url.Parse(initialURL.String())

host := u.Host
if u.Port() == "" {
Expand Down Expand Up @@ -109,7 +109,12 @@ func (drv *Driver) onClusterClause() string {
}

func (drv *Driver) databaseName() string {
name := strings.TrimLeft(dbutil.MustParseURL(connectionString(drv.databaseURL)).Path, "/")
u, err := url.Parse(connectionString(drv.databaseURL))
if err != nil {
panic(err)
}

name := strings.TrimLeft(u.Path, "/")
if name == "" {
name = "default"
}
Expand Down
Loading