Skip to content

Commit

Permalink
feat: add support for dialer (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom committed Mar 31, 2022
1 parent adf7975 commit 483ffda
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 360 deletions.
62 changes: 12 additions & 50 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,12 @@ import (
"time"

"cloud.google.com/go/cloudsqlconn/errtype"
"cloud.google.com/go/cloudsqlconn/internal/alloydb"
"cloud.google.com/go/cloudsqlconn/internal/cloudsql"
"cloud.google.com/go/cloudsqlconn/internal/trace"
"github.com/google/uuid"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)

const (
Expand All @@ -44,7 +42,7 @@ const (
// defaultTCPKeepAlive is the default keep alive value used on connections to a Cloud SQL instance.
defaultTCPKeepAlive = 30 * time.Second
// serverProxyPort is the port the server-side proxy receives connections on.
serverProxyPort = "3307"
serverProxyPort = "5433"
)

var (
Expand Down Expand Up @@ -72,7 +70,7 @@ type Dialer struct {
key *rsa.PrivateKey
refreshTimeout time.Duration

sqladmin *sqladmin.Service
client *alloydb.Client

// defaultDialCfg holds the constructor level DialOptions, so that it can
// be copied and mutated by the Dial function.
Expand All @@ -85,10 +83,6 @@ type Dialer struct {
// dialFunc is the function used to connect to the address on the named
// network. By default it is golang.org/x/net/proxy#Dial.
dialFunc func(cxt context.Context, network, addr string) (net.Conn, error)

// iamTokenSource supplies the OAuth2 token used for IAM DB Authn. If IAM DB
// Authn is not enabled, iamTokenSource will be nil.
iamTokenSource oauth2.TokenSource
}

// NewDialer creates a new Dialer.
Expand All @@ -99,7 +93,7 @@ type Dialer struct {
func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
cfg := &dialerConfig{
refreshTimeout: 30 * time.Second,
sqladminOpts: []option.ClientOption{option.WithUserAgent(userAgent)},
adminOpts: []option.ClientOption{option.WithUserAgent(userAgent)},
dialFunc: proxy.Dial,
}
for _, opt := range opts {
Expand All @@ -108,21 +102,6 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
return nil, cfg.err
}
}
// If callers have not provided a token source, either explicitly with
// WithTokenSource or implicitly with WithCredentialsJSON etc, then use the
// default token source.
if cfg.useIAMAuthN && cfg.tokenSource == nil {
ts, err := google.DefaultTokenSource(ctx, sqladmin.SqlserviceAdminScope)
if err != nil {
return nil, fmt.Errorf("failed to create token source: %v", err)
}
cfg.tokenSource = ts
}
// If IAM Authn is not explicitly enabled, remove the token source.
if !cfg.useIAMAuthN {
cfg.tokenSource = nil
}

if cfg.rsaKey == nil {
key, err := getDefaultKeys()
if err != nil {
Expand All @@ -131,13 +110,12 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
cfg.rsaKey = key
}

client, err := sqladmin.NewService(ctx, cfg.sqladminOpts...)
client, err := alloydb.NewClient(ctx, cfg.adminOpts...)
if err != nil {
return nil, fmt.Errorf("failed to create sqladmin client: %v", err)
return nil, fmt.Errorf("failed to create AlloyDB Admin API client: %v", err)
}

dialCfg := dialCfg{
ipType: "PUBLIC",
tcpKeepAlive: defaultTCPKeepAlive,
}
for _, opt := range cfg.dialOpts {
Expand All @@ -151,10 +129,9 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
instances: make(map[string]*cloudsql.Instance),
key: cfg.rsaKey,
refreshTimeout: cfg.refreshTimeout,
sqladmin: client,
client: client,
defaultDialCfg: dialCfg,
dialerID: uuid.New().String(),
iamTokenSource: cfg.tokenSource,
dialFunc: cfg.dialFunc,
}
return d, nil
Expand All @@ -165,7 +142,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption) (conn net.Conn, err error) {
startTime := time.Now()
var endDial trace.EndSpanFunc
ctx, endDial = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn.Dial",
ctx, endDial = trace.StartSpan(ctx, "cloud.google.com/go/alloydbconn.Dial",
trace.AddInstanceName(instance),
trace.AddDialerID(d.dialerID),
)
Expand All @@ -179,21 +156,21 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
}

var endInfo trace.EndSpanFunc
ctx, endInfo = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.InstanceInfo")
ctx, endInfo = trace.StartSpan(ctx, "cloud.google.com/go/alloydbconn/internal.InstanceInfo")
i, err := d.instance(instance)
if err != nil {
endInfo(err)
return nil, err
}
addr, tlsCfg, err := i.ConnectInfo(ctx, cfg.ipType)
addr, tlsCfg, err := i.ConnectInfo(ctx)
if err != nil {
endInfo(err)
return nil, err
}
endInfo(err)

var connectEnd trace.EndSpanFunc
ctx, connectEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.Connect")
ctx, connectEnd = trace.StartSpan(ctx, "cloud.google.com/go/alloydbconn/internal.Connect")
defer func() { connectEnd(err) }()
addr = net.JoinHostPort(addr, serverProxyPort)
conn, err = d.dialFunc(ctx, "tcp", addr)
Expand Down Expand Up @@ -230,21 +207,6 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
}), nil
}

// EngineVersion returns the engine type and version for the instance. The value will
// corespond to one of the following types for the instance:
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion
func (d *Dialer) EngineVersion(ctx context.Context, instance string) (string, error) {
i, err := d.instance(instance)
if err != nil {
return "", err
}
e, err := i.InstanceEngineVersion(ctx)
if err != nil {
return "", err
}
return e, nil
}

// newInstrumentedConn initializes an instrumentedConn that on closing will
// decrement the number of open connects and record the result.
func newInstrumentedConn(conn net.Conn, closeFunc func()) *instrumentedConn {
Expand Down Expand Up @@ -296,7 +258,7 @@ func (d *Dialer) instance(connName string) (*cloudsql.Instance, error) {
if !ok {
// Create a new instance
var err error
i, err = cloudsql.NewInstance(connName, d.sqladmin, d.key, d.refreshTimeout, d.iamTokenSource, d.dialerID)
i, err = cloudsql.NewInstance(connName, d.client, d.key, d.refreshTimeout, d.dialerID)
if err != nil {
d.lock.Unlock()
return nil, err
Expand Down

0 comments on commit 483ffda

Please sign in to comment.