Skip to content

Commit

Permalink
feat: add DialOptions for configuring Dial (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtisvg committed Apr 9, 2021
1 parent 1235a9f commit e2d53ee
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 50 deletions.
73 changes: 62 additions & 11 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,23 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"fmt"
"net"
"sync"
"time"

"cloud.google.com/cloudsqlconn/internal/cloudsql"
apiopt "google.golang.org/api/option"
"golang.org/x/net/proxy"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)

type dialerConfig struct {
sqladminOpts []apiopt.ClientOption
}
const (
// defaultTCPKeepAlive is the default keep alive value used on connections to a Cloud SQL instance.
defaultTCPKeepAlive = 30 * time.Second
// serverProxyPort is the port the server-side proxy receives connections on.
serverProxyPort = "3307"
)

// A Dialer is used to create connections to Cloud SQL instances.
type Dialer struct {
Expand All @@ -38,6 +43,10 @@ type Dialer struct {
key *rsa.PrivateKey

sqladmin *sqladmin.Service

// defaultDialCfg holds the constructor level DialOptions, so that it can
// be copied and mutated by the Dial function.
defaultDialCfg dialCfg
}

// NewDialer creates a new Dialer.
Expand All @@ -52,26 +61,68 @@ func NewDialer(ctx context.Context, opts ...DialerOption) (*Dialer, error) {
for _, opt := range opts {
opt(cfg)
}

client, err := sqladmin.NewService(context.Background(), cfg.sqladminOpts...)
if err != nil {
return nil, fmt.Errorf("failed to create sqladmin client: %v", err)
}

dialCfg := dialCfg{
ipType: cloudsql.PublicIP,
tcpKeepAlive: defaultTCPKeepAlive,
}
for _, opt := range cfg.dialOpts {
opt(&dialCfg)
}

d := &Dialer{
instances: make(map[string]*cloudsql.Instance),
sqladmin: client,
key: key,
instances: make(map[string]*cloudsql.Instance),
sqladmin: client,
key: key,
defaultDialCfg: dialCfg,
}
return d, nil
}

// Dial creates an authorized connection to a Cloud SQL instance specified by it's instance connection name.
func (d *Dialer) Dial(ctx context.Context, instance string) (net.Conn, error) {
// Dial returns a net.Conn connected to the specified Cloud SQL instance. The instance argument must be the
// instance's connection name, which is in the format "project-name:region:instance-name".
func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption) (net.Conn, error) {
cfg := d.defaultDialCfg
for _, opt := range opts {
opt(&cfg)
}

i, err := d.instance(instance)
if err != nil {
return nil, err
}
return i.Connect(ctx)
ipAddrs, tlsCfg, err := i.ConnectInfo(ctx)
if err != nil {
return nil, err
}
addr, ok := ipAddrs[cfg.ipType]
if !ok {
return nil, fmt.Errorf("instance '%s' does not have IP of type '%s'", instance, cfg.ipType)
}
addr = net.JoinHostPort(addr, serverProxyPort)

conn, err := proxy.Dial(ctx, "tcp", addr)
if err != nil {
return nil, err
}
if c, ok := conn.(*net.TCPConn); ok {
if err := c.SetKeepAlive(true); err != nil {
return nil, fmt.Errorf("failed to set keep-alive: %v", err)
}
if err := c.SetKeepAlivePeriod(cfg.tcpKeepAlive); err != nil {
return nil, fmt.Errorf("failed to set keep-alive period: %v", err)
}
}
tlsConn := tls.Client(conn, tlsCfg)
if err := tlsConn.Handshake(); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("handshake failed: %w", err)
}
return tlsConn, err
}

func (d *Dialer) instance(connName string) (*cloudsql.Instance, error) {
Expand Down
34 changes: 4 additions & 30 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,17 @@ import (
"crypto/rsa"
"crypto/tls"
"fmt"
"net"
"regexp"
"sync"
"time"

"golang.org/x/net/proxy"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)

var (
// Instance connection name is the format <PROJECT>:<REGION>:<INSTANCE>
// Additionally, we have to support legacy "domain-scoped" projects (e.g. "google.com:PROJECT")
connNameRegex = regexp.MustCompile("([^:]+(:[^:]+)?):([^:]+):([^:]+)")

// defaultKeepAlive is the default keep alive value on connections created in this package.
defaultKeepAlive = 30 * time.Second
)

// connName represents the "instance connection name", in the format "project:region:name". Use the
Expand Down Expand Up @@ -133,37 +128,16 @@ func NewInstance(instance string, client *sqladmin.Service, key *rsa.PrivateKey)
return i, nil
}

// Connect returns a secure, authorized net.Conn to a Cloud SQL instance.
func (i *Instance) Connect(ctx context.Context) (net.Conn, error) {
// ConnectInfo returns a map of IP types and a TLS config that can be used to connect to a Cloud SQL instance.
func (i *Instance) ConnectInfo(ctx context.Context) (map[string]string, *tls.Config, error) {
i.resultGuard.RLock()
res := i.cur
i.resultGuard.RUnlock()
err := res.Wait(ctx)
if err != nil {
return nil, err
}

// TODO: Add better ipType support, including an opt to specify.
addr := net.JoinHostPort(res.md.ipAddrs["PUBLIC"], "3307")
conn, err := proxy.Dial(ctx, "tcp", addr)
if err != nil {
return nil, err
}
if c, ok := conn.(*net.TCPConn); ok {
if err := c.SetKeepAlive(true); err != nil {
return nil, fmt.Errorf("failed to set keep-alive: %w", err)
}
if err := c.SetKeepAlivePeriod(defaultKeepAlive); err != nil {
return nil, fmt.Errorf("failed to set keep-alive period: %w", err)
}
}

tlsConn := tls.Client(conn, res.tlsCfg)
if err := tlsConn.Handshake(); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("handshake failed: %w", err)
return nil, nil, err
}
return tlsConn, err
return res.md.ipAddrs, res.tlsCfg, nil
}

// scheduleRefresh schedules a refresh operation to be triggered after a given duration. The returned refreshResult
Expand Down
9 changes: 4 additions & 5 deletions internal/cloudsql/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestParseConnName(t *testing.T) {
}
}

func TestConnect(t *testing.T) {
func TestConnectInfo(t *testing.T) {
ctx := context.Background()

client, err := sqladmin.NewService(ctx)
Expand All @@ -74,12 +74,11 @@ func TestConnect(t *testing.T) {

im, err := NewInstance(instConnName, client, key)
if err != nil {
t.Fatalf("failed to initialize Instance Manager: %v", err)
t.Fatalf("failed to initialize Instance: %v", err)
}

conn, err := im.Connect(ctx)
_, _, err = im.ConnectInfo(ctx)
if err != nil {
t.Fatalf("failed to connect: %v", err)
t.Fatalf("failed to retrieve connect info: %v", err)
}
conn.Close()
}
11 changes: 8 additions & 3 deletions internal/cloudsql/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ import (
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)

const (
PublicIP = "PUBLIC"
PrivateIP = "PRIVATE"
)

// metadata contains information about a Cloud SQL instance needed to create connections.
type metadata struct {
ipAddrs map[string]string
Expand Down Expand Up @@ -54,9 +59,9 @@ func fetchMetadata(ctx context.Context, client *sqladmin.Service, inst connName)
for _, ip := range db.IpAddresses {
switch ip.Type {
case "PRIMARY":
ipAddrs["PUBLIC"] = ip.IpAddress
ipAddrs[PublicIP] = ip.IpAddress
case "PRIVATE":
ipAddrs["PRIVATE"] = ip.IpAddress
ipAddrs[PrivateIP] = ip.IpAddress
}
}
if len(ipAddrs) == 0 {
Expand Down Expand Up @@ -84,7 +89,7 @@ func fetchMetadata(ctx context.Context, client *sqladmin.Service, inst connName)

// fetchEphemeralCert uses the Cloud SQL Admin API's createEphemeral method to create a signed TLS
// certificate that authorized to connect via the Cloud SQL instance's serverside proxy. The cert
// if valid for aproximately one hour.
// if valid for approximately one hour.
func fetchEphemeralCert(ctx context.Context, client *sqladmin.Service, inst connName, key *rsa.PrivateKey) (tls.Certificate, error) {
clientPubKey, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
if err != nil {
Expand Down
55 changes: 54 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,22 @@
package cloudsqlconn

import (
"time"

"cloud.google.com/cloudsqlconn/internal/cloudsql"
"golang.org/x/oauth2"
apiopt "google.golang.org/api/option"
)

// A DialerOption is an option for configuring a Dialer.
type DialerOption func(d *dialerConfig)

// DialerOptions turns a list of DialerOption instances into a DialerOption.
type dialerConfig struct {
sqladminOpts []apiopt.ClientOption
dialOpts []DialOption
}

// DialerOptions turns a list of DialerOption instances into an DialerOption.
func DialerOptions(opts ...DialerOption) DialerOption {
return func(d *dialerConfig) {
for _, opt := range opts {
Expand All @@ -45,9 +53,54 @@ func WithCredentialsJSON(p []byte) DialerOption {
}
}

// WithDefaultDialOption returns a DialerOption that specifies the default DialOptions used.
func WithDefaultDialOptions(opts ...DialOption) DialerOption {
return func(d *dialerConfig) {
d.dialOpts = append(d.dialOpts, opts...)
}
}

// WithTokenSource returns a DialerOption that specifies an OAuth2 token source to be used as the basis for authentication.
func WithTokenSource(s oauth2.TokenSource) DialerOption {
return func(d *dialerConfig) {
d.sqladminOpts = append(d.sqladminOpts, apiopt.WithTokenSource(s))
}
}

// A DialOption is an option for configuring how a Dialer's Dial call is executed.
type DialOption func(d *dialCfg)

type dialCfg struct {
tcpKeepAlive time.Duration
ipType string
}

// DialOptions turns a list of DialOption instances into an DialOption.
func DialOptions(opts ...DialOption) DialOption {
return func(cfg *dialCfg) {
for _, opt := range opts {
opt(cfg)
}
}
}

// WithTCPKeepAlive returns a DialOption that specifies the tcp keep alive period for the connection returned by Dial.
func WithTCPKeepAlive(d time.Duration) DialOption {
return func(cfg *dialCfg) {
cfg.tcpKeepAlive = d
}
}

// WithPublicIP returns a DialOption that specifies a public IP will be used to connect.
func WithPublicIP() DialOption {
return func(cfg *dialCfg) {
cfg.ipType = cloudsql.PublicIP
}
}

// WithPrivateIP returns a DialOption that specifies a private IP (VPC) will be used to connect.
func WithPrivateIP() DialOption {
return func(cfg *dialCfg) {
cfg.ipType = cloudsql.PrivateIP
}
}

0 comments on commit e2d53ee

Please sign in to comment.