Skip to content

Commit

Permalink
fix: return a friendly error if the dialer is closed (#766)
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom committed Apr 9, 2024
1 parent 9684600 commit d1c13e0
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
23 changes: 21 additions & 2 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ const (
)

var (
// ErrDialerClosed is used when a caller invokes Dial after closing the
// Dialer.
ErrDialerClosed = errors.New("cloudsqlconn: dialer is closed")
// versionString indicates the version of this library.
//go:embed version.txt
versionString string
Expand Down Expand Up @@ -91,8 +94,11 @@ type Dialer struct {
instances map[instance.ConnName]connectionInfoCache
key *rsa.PrivateKey
refreshTimeout time.Duration
sqladmin *sqladmin.Service
logger debug.Logger
// closed reports if the dialer has been closed.
closed chan struct{}

sqladmin *sqladmin.Service
logger debug.Logger

// defaultDialConfig holds the constructor level DialOptions, so that it
// can be copied and mutated by the Dial function.
Expand Down Expand Up @@ -210,6 +216,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
return nil, err
}
d := &Dialer{
closed: make(chan struct{}),
instances: make(map[instance.ConnName]connectionInfoCache),
key: cfg.rsaKey,
refreshTimeout: cfg.refreshTimeout,
Expand All @@ -227,6 +234,11 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
// icn argument must be the instance's connection name, which is in the format
// "project-name:region:instance-name".
func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn net.Conn, err error) {
select {
case <-d.closed:
return nil, ErrDialerClosed
default:
}
startTime := time.Now()
var endDial trace.EndSpanFunc
ctx, endDial = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn.Dial",
Expand Down Expand Up @@ -420,6 +432,13 @@ func (i *instrumentedConn) Close() error {
// needed to connect. Additional dial operations may succeed until the information
// expires.
func (d *Dialer) Close() error {
// Check if Close has already been called.
select {
case <-d.closed:
return nil
default:
}
close(d.closed)
d.lock.Lock()
defer d.lock.Unlock()
for _, i := range d.instances {
Expand Down
24 changes: 24 additions & 0 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -795,3 +795,27 @@ func TestDialerSupportsOneOffDialFunction(t *testing.T) {
t.Fatal("one-off dial func was not called")
}
}

func TestDialerCloseReportsFriendlyError(t *testing.T) {
d, err := NewDialer(
context.Background(),
WithTokenSource(mock.EmptyTokenSource{}),
)
if err != nil {
t.Fatal(err)
}
_ = d.Close()

_, err = d.Dial(context.Background(), "p:r:i")
if !errors.Is(err, ErrDialerClosed) {
t.Fatalf("want = %v, got = %v", ErrDialerClosed, err)
}

// Ensure multiple calls to close don't panic
_ = d.Close()

_, err = d.Dial(context.Background(), "p:r:i")
if !errors.Is(err, ErrDialerClosed) {
t.Fatalf("want = %v, got = %v", ErrDialerClosed, err)
}
}

0 comments on commit d1c13e0

Please sign in to comment.