Skip to content

Commit

Permalink
fix: return a friendly error if the dialer is closed (#538)
Browse files Browse the repository at this point in the history
If the dialer has already been closed, return a clear error.

Fixes #522
  • Loading branch information
enocom committed Apr 9, 2024
1 parent fe8b433 commit 66d7bd0
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 8 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ d, err := alloydbconn.NewDialer(ctx)
if err != nil {
log.Fatalf("failed to initialize dialer: %v", err)
}
// Don't close the dialer until you're done with the database connection
// e.g. at the end of your main function
defer d.Close()

// Tell the driver to use the AlloyDB Go Connector to create connections
Expand Down
2 changes: 1 addition & 1 deletion database_sql_public_ip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func connectDatabaseSQLWithPublicIP(
}

db, err := sql.Open(
"alloydb",
"alloydb-public",
fmt.Sprintf(
// sslmode is disabled, because the Dialer will handle the SSL
// connection instead.
Expand Down
19 changes: 19 additions & 0 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ const (
)

var (
// ErrDialerClosed is used when a caller invokes Dial after closing the
// Dialer.
ErrDialerClosed = errors.New("alloydbconn: dialer is closed")
// versionString indicates the version of this library.
//go:embed version.txt
versionString string
Expand Down Expand Up @@ -90,6 +93,8 @@ type Dialer struct {
instances map[alloydb.InstanceURI]connectionInfoCache
key *rsa.PrivateKey
refreshTimeout time.Duration
// closed reports if the dialer has been closed.
closed chan struct{}

client *alloydbadmin.AlloyDBAdminClient
logger debug.Logger
Expand Down Expand Up @@ -174,6 +179,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
return nil, err
}
d := &Dialer{
closed: make(chan struct{}),
instances: make(map[alloydb.InstanceURI]connectionInfoCache),
key: cfg.rsaKey,
refreshTimeout: cfg.refreshTimeout,
Expand All @@ -194,6 +200,11 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
// instance argument must be the instance's URI, which is in the format
// projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>
func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption) (conn net.Conn, err error) {
select {
case <-d.closed:
return nil, ErrDialerClosed
default:
}
startTime := time.Now()
var endDial trace.EndSpanFunc
ctx, endDial = trace.StartSpan(ctx, "cloud.google.com/go/alloydbconn.Dial",
Expand Down Expand Up @@ -508,6 +519,14 @@ func (i *instrumentedConn) Close() error {
// needed to connect. Additional dial operations may succeed until the information
// expires.
func (d *Dialer) Close() error {
// Check if Close has already been called.
select {
case <-d.closed:
return nil
default:
}
close(d.closed)

d.lock.Lock()
defer d.lock.Unlock()
for _, i := range d.instances {
Expand Down
41 changes: 34 additions & 7 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ import (
"google.golang.org/api/option"
)

const testInstanceURI = "projects/my-project/locations/my-region/" +
"clusters/my-cluster/instances/my-instance"

type stubTokenSource struct{}

func (stubTokenSource) Token() (*oauth2.Token, error) {
Expand Down Expand Up @@ -72,7 +75,7 @@ func TestDialerCanConnectToInstance(t *testing.T) {
// reset between connections.
for i := 0; i < 10; i++ {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
conn, err := d.Dial(ctx, "projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance")
conn, err := d.Dial(ctx, testInstanceURI)
if err != nil {
t.Fatalf("expected Dial to succeed, but got error: %v", err)
}
Expand Down Expand Up @@ -116,12 +119,12 @@ func TestDialWithAdminAPIErrors(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
cancel()

_, err = d.Dial(ctx, "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance")
_, err = d.Dial(ctx, testInstanceURI)
if !errors.Is(err, context.Canceled) {
t.Fatalf("when context is canceled, want = %T, got = %v", context.Canceled, err)
}

_, err = d.Dial(context.Background(), "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance")
_, err = d.Dial(context.Background(), testInstanceURI)
var wantErr2 *errtype.RefreshError
if !errors.As(err, &wantErr2) {
t.Fatalf("when API call fails, want = %T, got = %v", wantErr2, err)
Expand Down Expand Up @@ -152,7 +155,7 @@ func TestDialWithUnavailableServerErrors(t *testing.T) {
}
d.client = c

_, err = d.Dial(ctx, "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance")
_, err = d.Dial(ctx, testInstanceURI)
var wantErr2 *errtype.DialError
if !errors.As(err, &wantErr2) {
t.Fatalf("when server proxy socket is unavailable, want = %T, got = %v", wantErr2, err)
Expand Down Expand Up @@ -191,7 +194,7 @@ func TestDialerWithCustomDialFunc(t *testing.T) {
}
d.client = c

_, err = d.Dial(ctx, "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance")
_, err = d.Dial(ctx, testInstanceURI)
if !strings.Contains(err.Error(), "sentinel error") {
t.Fatalf("want = sentinel error, got = %v", err)
}
Expand Down Expand Up @@ -275,7 +278,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {
}

sentinel := errors.New("connect info failed")
inst := "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance"
inst := testInstanceURI
cn, _ := alloydb.ParseInstURI(inst)
spy := &spyConnectionInfoCache{
connectInfoCalls: []struct {
Expand Down Expand Up @@ -410,8 +413,32 @@ func TestDialerSupportsOneOffDialFunction(t *testing.T) {
return nil, sentinelErr
}

_, err = d.Dial(ctx, "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance", WithOneOffDialFunc(f))
_, err = d.Dial(ctx, testInstanceURI, WithOneOffDialFunc(f))
if !errors.Is(err, sentinelErr) {
t.Fatal("one-off dial func was not called")
}
}

func TestDialerCloseReportsFriendlyError(t *testing.T) {
d, err := NewDialer(
context.Background(),
WithTokenSource(stubTokenSource{}),
)
if err != nil {
t.Fatal(err)
}
_ = d.Close()

_, err = d.Dial(context.Background(), testInstanceURI)
if !errors.Is(err, ErrDialerClosed) {
t.Fatalf("want = %v, got = %v", ErrDialerClosed, err)
}

// Ensure multiple calls to close don't panic
_ = d.Close()

_, err = d.Dial(context.Background(), testInstanceURI)
if !errors.Is(err, ErrDialerClosed) {
t.Fatalf("want = %v, got = %v", ErrDialerClosed, err)
}
}

0 comments on commit 66d7bd0

Please sign in to comment.