Skip to content

Commit

Permalink
Merge d789016 into 079ba81
Browse files Browse the repository at this point in the history
  • Loading branch information
nockty committed Jun 21, 2019
2 parents 079ba81 + d789016 commit 63188a8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
12 changes: 7 additions & 5 deletions driver.go
Expand Up @@ -13,14 +13,16 @@ import (
)

func init() {
sql.Register("pq-timeouts", timeoutDriver{dialOpen: pq.DialOpen})
sql.Register("pq-timeouts", &TimeoutDriver{DialOpen: pq.DialOpen})
}

type timeoutDriver struct {
dialOpen func(pq.Dialer, string) (driver.Conn, error) // Allow this to be stubbed for testing
// TimeoutDriver is the Postgres database driver, providing read and write timeouts.
type TimeoutDriver struct {
DialOpen func(pq.Dialer, string) (driver.Conn, error) // Allow this to be stubbed for testing
}

func (t timeoutDriver) Open(connection string) (_ driver.Conn, err error) {
// Open creates a new connection to the database by using the given connection string.
func (t TimeoutDriver) Open(connection string) (_ driver.Conn, err error) {
// Look for read_timeout and write_timeout in the connection string and extract the values.
// read_timeout and write_timeout need to be removed from the connection string before calling pq as well.
var newConnectionSettings []string
Expand Down Expand Up @@ -56,7 +58,7 @@ func (t timeoutDriver) Open(connection string) (_ driver.Conn, err error) {

newConnectionStr := strings.Join(newConnectionSettings, " ")

return t.dialOpen(
return t.DialOpen(
timeoutDialer{
netDial: net.Dial,
netDialTimeout: net.DialTimeout,
Expand Down
14 changes: 7 additions & 7 deletions driver_test.go
Expand Up @@ -19,7 +19,7 @@ func TestOpenNoTimeoutsAdded(t *testing.T) {
return nil, nil
}

driver := timeoutDriver{dialOpen: testDialOpen}
driver := TimeoutDriver{DialOpen: testDialOpen}

testConnection := "user=pqtest dbname=pqtest sslmode=verify-full"
_, err := driver.Open(testConnection)
Expand Down Expand Up @@ -47,7 +47,7 @@ func TestOpenReadTimeoutAdded(t *testing.T) {
return nil, nil
}

driver := timeoutDriver{dialOpen: testDialOpen}
driver := TimeoutDriver{DialOpen: testDialOpen}

_, err := driver.Open("user=pqtest read_timeout=700 dbname=pqtest sslmode=verify-full")

Expand Down Expand Up @@ -82,7 +82,7 @@ func TestOpenWriteTimeoutAdded(t *testing.T) {
return nil, nil
}

driver := timeoutDriver{dialOpen: testDialOpen}
driver := TimeoutDriver{DialOpen: testDialOpen}

_, err := driver.Open(" user=pqtest write_timeout=968 dbname=pqtest sslmode=verify-full ")

Expand Down Expand Up @@ -115,7 +115,7 @@ func TestOpenTimeoutsAddedWriteError(t *testing.T) {
return nil, nil
}

driver := timeoutDriver{dialOpen: testDialOpen}
driver := TimeoutDriver{DialOpen: testDialOpen}

_, err := driver.Open(" user=pqtest write_timeout=seven read_timeout=7 dbname=pqtest sslmode=verify-full ")

Expand All @@ -140,7 +140,7 @@ func TestOpenTimeoutsAddedReadError(t *testing.T) {
return nil, nil
}

driver := timeoutDriver{dialOpen: testDialOpen}
driver := TimeoutDriver{DialOpen: testDialOpen}

_, err := driver.Open(" user=pqtest write_timeout=680 read_timeout= dbname=pqtest sslmode=verify-full ")

Expand All @@ -167,7 +167,7 @@ func TestPostgresURL(t *testing.T) {
return nil, nil
}

driver := timeoutDriver{dialOpen: testDialOpen}
driver := TimeoutDriver{DialOpen: testDialOpen}

_, err := driver.Open("postgres://pqtest:password@localhost/pqtest?read_timeout=500&sslmode=verify-full&write_timeout=100")

Expand Down Expand Up @@ -201,7 +201,7 @@ func TestPostgresqlURLError(t *testing.T) {
return nil, nil
}

driver := timeoutDriver{dialOpen: testDialOpen}
driver := TimeoutDriver{DialOpen: testDialOpen}

_, err := driver.Open("postgresql://pqtest\\\\/:password@localhost/pqtest?read_timeout=500&sslmode=verify-full&write_timeout=100")

Expand Down

0 comments on commit 63188a8

Please sign in to comment.