Skip to content

Commit

Permalink
feat: add support for a custom dialer (#1158)
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom committed May 5, 2022
1 parent 9d5512d commit 0ac17a5
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 49 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ jobs:
- run: goimports -w .
- run: go mod tidy
- name: Verify no changes from goimports and go mod tidy. If you're reading this and the check has failed, run `goimports -w . && go mod tidy`.
run: git diff --exit-code
34 changes: 34 additions & 0 deletions cloudsql/cloudsql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cloudsql

import (
"context"
"io"
"net"

"cloud.google.com/go/cloudsqlconn"
)

// Dialer dials a Cloud SQL instance and returns its database engine version.
type Dialer interface {
// Dial returns a connection to the specified instance.
Dial(ctx context.Context, inst string, opts ...cloudsqlconn.DialOption) (net.Conn, error)
// EngineVersion retrieves the provided instance's database version (e.g.,
// POSTGRES_14)
EngineVersion(ctx context.Context, inst string) (string, error)

io.Closer
}
2 changes: 1 addition & 1 deletion cmd/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ var (
}

errSigTerm = &exitError{
Err: errors.New("SIGINT signal received"),
Err: errors.New("SIGTERM signal received"),
Code: 137,
}
)
Expand Down
33 changes: 27 additions & 6 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"syscall"

"cloud.google.com/go/cloudsqlconn"
"github.com/GoogleCloudPlatform/cloudsql-proxy/v2/cloudsql"
"github.com/GoogleCloudPlatform/cloudsql-proxy/v2/internal/proxy"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -62,11 +63,25 @@ type Command struct {
conf *proxy.Config
}

// Option is a function that configures a Command.
type Option func(*proxy.Config)

// WithDialer configures the Command to use the provided dialer to connect to
// Cloud SQL instances.
func WithDialer(d cloudsql.Dialer) Option {
return func(c *proxy.Config) {
c.Dialer = d
}
}

// NewCommand returns a Command object representing an invocation of the proxy.
func NewCommand() *Command {
func NewCommand(opts ...Option) *Command {
c := &Command{
conf: &proxy.Config{},
}
for _, o := range opts {
o(c.conf)
}

cmd := &cobra.Command{
Use: "cloud_sql_proxy instance_connection_name...",
Expand Down Expand Up @@ -192,11 +207,17 @@ func runSignalWrapper(cmd *Command) error {
startCh := make(chan *proxy.Client)
go func() {
defer close(startCh)
opts := append(cmd.conf.DialerOpts(), cloudsqlconn.WithUserAgent(userAgent))
d, err := cloudsqlconn.NewDialer(ctx, opts...)
if err != nil {
shutdownCh <- fmt.Errorf("error initializing dialer: %v", err)
return
// Check if the caller has configured a dialer.
// Otherwise, initialize a new one.
d := cmd.conf.Dialer
if d == nil {
var err error
opts := append(cmd.conf.DialerOpts(), cloudsqlconn.WithUserAgent(userAgent))
d, err = cloudsqlconn.NewDialer(ctx, opts...)
if err != nil {
shutdownCh <- fmt.Errorf("error initializing dialer: %v", err)
return
}
}
p, err := proxy.NewClient(ctx, d, cmd.Command, cmd.conf)
if err != nil {
Expand Down
57 changes: 56 additions & 1 deletion cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@
package cmd

import (
"context"
"errors"
"net"
"sync"
"testing"
"time"

"cloud.google.com/go/cloudsqlconn"
"github.com/GoogleCloudPlatform/cloudsql-proxy/v2/internal/proxy"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -131,7 +138,7 @@ func TestNewCommandArguments(t *testing.T) {
t.Fatalf("want error = nil, got = %v", err)
}

if got := c.conf; !cmp.Equal(tc.want, got) {
if got := c.conf; !cmp.Equal(tc.want, got, cmpopts.IgnoreUnexported(proxy.Config{})) {
t.Fatalf("want = %#v\ngot = %#v\ndiff = %v", tc.want, got, cmp.Diff(tc.want, got))
}
})
Expand Down Expand Up @@ -200,3 +207,51 @@ func TestNewCommandWithErrors(t *testing.T) {
})
}
}

type spyDialer struct {
mu sync.Mutex
got string
}

func (s *spyDialer) instance() string {
s.mu.Lock()
defer s.mu.Unlock()
i := s.got
return i
}

func (*spyDialer) Dial(_ context.Context, inst string, _ ...cloudsqlconn.DialOption) (net.Conn, error) {
return nil, errors.New("spy dialer does not dial")
}

func (s *spyDialer) EngineVersion(ctx context.Context, inst string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.got = inst
return "", nil
}

func (*spyDialer) Close() error {
return nil
}

func TestCommandWithCustomDialer(t *testing.T) {
want := "my-project:my-region:my-instance"
s := &spyDialer{}
c := NewCommand(WithDialer(s))
// Keep the test output quiet
c.SilenceUsage = true
c.SilenceErrors = true
c.SetArgs([]string{want})

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

if err := c.ExecuteContext(ctx); !errors.As(err, &errSigInt) {
t.Fatalf("want errSigInt, got = %v", err)
}

if got := s.instance(); got != want {
t.Fatalf("want = %v, got = %v", want, got)
}
}
46 changes: 23 additions & 23 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"time"

"cloud.google.com/go/cloudsqlconn"
"github.com/GoogleCloudPlatform/cloudsql-proxy/v2/cloudsql"
"github.com/spf13/cobra"
"golang.org/x/oauth2"
)
Expand All @@ -47,16 +48,26 @@ type Config struct {
// Addr is the address on which to bind all instances.
Addr string

// Port is the first port to bind to. Subsequent ports will increment from
// this value.
// Port is the initial port to bind to. Subsequent instances bind to
// increments from this value.
Port int

// Instances are configuration for individual instances. Instance
// configuration takes precedence over global configuration.
Instances []InstanceConnConfig

// Dialer specifies the dialer to use when connecting to Cloud SQL
// instances.
Dialer cloudsql.Dialer
}

// NewConfig initializes a Config struct using the default database engine
// ports.
func NewConfig() *Config {
return &Config{}
}

func (c Config) DialerOpts() []cloudsqlconn.Option {
func (c *Config) DialerOpts() []cloudsqlconn.Option {
var opts []cloudsqlconn.Option
if c.Token != "" {
opts = append(opts, cloudsqlconn.WithTokenSource(
Expand All @@ -66,15 +77,6 @@ func (c Config) DialerOpts() []cloudsqlconn.Option {
return opts
}

// Client represents the state of the current instantiation of the proxy.
type Client struct {
cmd *cobra.Command
dialer Dialer

// mnts is a list of all mounted sockets for this client
mnts []*socketMount
}

type portConfig struct {
global int
postgres int
Expand Down Expand Up @@ -118,27 +120,25 @@ func (c *portConfig) nextDBPort(version string) int {
}
}

// Dialer dials a Cloud SQL instance and returns its database engine version.
type Dialer interface {
// Dial returns a connection to the specified instance.
Dial(ctx context.Context, inst string, opts ...cloudsqlconn.DialOption) (net.Conn, error)
// EngineVersion retrieves the provided instance's database version (e.g.,
// POSTGRES_14)
EngineVersion(ctx context.Context, inst string) (string, error)
// Close terminates all background operations and stops the dialer.
Close() error
// Client represents the state of the current instantiation of the proxy.
type Client struct {
cmd *cobra.Command
dialer cloudsql.Dialer

// mnts is a list of all mounted sockets for this client
mnts []*socketMount
}

// NewClient completes the initial setup required to get the proxy to a "steady" state.
func NewClient(ctx context.Context, d Dialer, cmd *cobra.Command, conf *Config) (*Client, error) {
func NewClient(ctx context.Context, d cloudsql.Dialer, cmd *cobra.Command, conf *Config) (*Client, error) {
var mnts []*socketMount
pc := newPortConfig(conf.Port)
for _, inst := range conf.Instances {
go func(name string) {
// Initiate refresh operation
d.EngineVersion(ctx, name)
}(inst.Name)
}
pc := newPortConfig(conf.Port)
for _, inst := range conf.Instances {
m := &socketMount{inst: inst.Name}
a := conf.Addr
Expand Down
24 changes: 6 additions & 18 deletions internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ import (
"strings"
"testing"

"github.com/GoogleCloudPlatform/cloudsql-proxy/v2/cloudsql"
"github.com/GoogleCloudPlatform/cloudsql-proxy/v2/internal/proxy"
"github.com/spf13/cobra"
)

type fakeDialer struct {
proxy.Dialer
cloudsql.Dialer
}

func (fakeDialer) Close() error {
Expand Down Expand Up @@ -122,22 +123,6 @@ func TestClientInitialization(t *testing.T) {
"127.0.0.1:5001",
},
},
{
desc: "with automatic port selection",
in: &proxy.Config{
Addr: "127.0.0.1",
Instances: []proxy.InstanceConnConfig{
{Name: pg},
{Name: mysql},
{Name: sqlserver},
},
},
wantAddrs: []string{
"127.0.0.1:5432",
"127.0.0.1:3306",
"127.0.0.1:1433",
},
},
{
desc: "with incrementing automatic port selection",
in: &proxy.Config{
Expand Down Expand Up @@ -174,7 +159,10 @@ func TestClientInitialization(t *testing.T) {
if err != nil {
t.Fatalf("want error = nil, got = %v", err)
}
defer conn.Close()
err = conn.Close()
if err != nil {
t.Logf("failed to close connection: %v", err)
}
}
})
}
Expand Down

0 comments on commit 0ac17a5

Please sign in to comment.