Skip to content

Commit

Permalink
feat: add support for WithOneOffDialFunc (#558)
Browse files Browse the repository at this point in the history
Fixes #551.
  • Loading branch information
enocom committed Jun 6, 2023
1 parent fbdf37c commit 14592f3
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
6 changes: 5 additions & 1 deletion dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,11 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
ctx, connectEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.Connect")
defer func() { connectEnd(err) }()
addr = net.JoinHostPort(addr, serverProxyPort)
conn, err = d.dialFunc(ctx, "tcp", addr)
f := d.dialFunc
if cfg.dialFunc != nil {
f = cfg.dialFunc
}
conn, err = f(ctx, "tcp", addr)
if err != nil {
// refresh the instance info in case it caused the connection failure
i.ForceRefresh()
Expand Down
33 changes: 33 additions & 0 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -567,3 +567,36 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {
t.Fatal("performRefresh should not be running")
}
}

func TestDialerSupportsOneOffDialFunction(t *testing.T) {
ctx := context.Background()
inst := mock.NewFakeCSQLInstance("p", "r", "i")
svc, cleanup, err := mock.NewSQLAdminService(
context.Background(),
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
)
if err != nil {
t.Fatalf("failed to init SQLAdminService: %v", err)
}
d, err := NewDialer(ctx, WithTokenSource(mock.EmptyTokenSource{}))
if err != nil {
t.Fatal(err)
}
d.sqladmin = svc
defer func() {
if err := d.Close(); err != nil {
t.Log(err)
}
_ = cleanup()
}()

sentinelErr := errors.New("dial func was called")
f := func(context.Context, string, string) (net.Conn, error) {
return nil, sentinelErr
}

if _, err := d.Dial(ctx, "p:r:i", WithOneOffDialFunc(f)); !errors.Is(err, sentinelErr) {
t.Fatal("one-off dial func was not called")
}
}
18 changes: 14 additions & 4 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ func WithQuotaProject(p string) Option {

// WithDialFunc configures the function used to connect to the address on the
// named network. This option is generally unnecessary except for advanced
// use-cases.
// use-cases. The function is used for all invocations of Dial. To configure
// a dial function per individual calls to dial, use WithOneOffDialFunc.
func WithDialFunc(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
return func(d *dialerConfig) {
d.dialFunc = dial
Expand All @@ -212,10 +213,10 @@ func WithIAMAuthN() Option {
type DialOption func(d *dialCfg)

type dialCfg struct {
tcpKeepAlive time.Duration
dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
ipType string

refreshCfg cloudsql.RefreshCfg
tcpKeepAlive time.Duration
refreshCfg cloudsql.RefreshCfg
}

// DialOptions turns a list of DialOption instances into an DialOption.
Expand All @@ -227,6 +228,15 @@ func DialOptions(opts ...DialOption) DialOption {
}
}

// WithOneOffDialFunc configures the dial function on a one-off basis for an
// individual call to Dial. To configure a dial function across all invocations
// of Dial, use WithDialFunc.
func WithOneOffDialFunc(dial func(ctx context.Context, network, addr string) (net.Conn, error)) DialOption {
return func(c *dialCfg) {
c.dialFunc = dial
}
}

// WithTCPKeepAlive returns a DialOption that specifies the tcp keep alive period for the connection returned by Dial.
func WithTCPKeepAlive(d time.Duration) DialOption {
return func(cfg *dialCfg) {
Expand Down

0 comments on commit 14592f3

Please sign in to comment.