Skip to content

Commit

Permalink
feat: Add EngineVersion method to Dialer (#59)
Browse files Browse the repository at this point in the history
* feat: Add EngineVersion method to Dialer

* chore(style): remove named returned values
  • Loading branch information
kurtisvg authored Jan 7, 2022
1 parent 4cb523e commit 6a78bfd
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 4 deletions.
15 changes: 15 additions & 0 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,21 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
return newInstrumentedConn(tlsConn, instance, d.dialerID), nil
}

// EngineVersion returns the engine type and version for the instance. The value will
// corespond to one of the following types for the instance:
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion
func (d *Dialer) EngineVersion(ctx context.Context, instance string) (string, error) {
i, err := d.instance(instance)
if err != nil {
return "", err
}
e, err := i.InstanceEngineVersion(ctx)
if err != nil {
return "", err
}
return e, nil
}

// newInstrumentedConn initializes an instrumentedConn that on closing will
// decrement the number of open connects and record the result.
func newInstrumentedConn(conn net.Conn, instance, dialerID string) *instrumentedConn {
Expand Down
51 changes: 51 additions & 0 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ func TestDialerCanConnectToInstance(t *testing.T) {
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
)
if err != nil {
t.Fatalf("failed to init SQLAdminService: %v", err)
}
stop := mock.StartServerProxy(t, inst)
defer func() {
stop()
Expand Down Expand Up @@ -69,6 +72,9 @@ func TestDialerCanConnectToInstance(t *testing.T) {
func TestDialWithAdminAPIErrors(t *testing.T) {
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
svc, cleanup, err := mock.NewSQLAdminService(context.Background())
if err != nil {
t.Fatalf("failed to init SQLAdminService: %v", err)
}
stop := mock.StartServerProxy(t, inst)
defer func() {
stop()
Expand Down Expand Up @@ -115,6 +121,9 @@ func TestDialWithConfigurationErrors(t *testing.T) {
mock.InstanceGetSuccess(inst, 2),
mock.CreateEphemeralSuccess(inst, 2),
)
if err != nil {
t.Fatalf("failed to init SQLAdminService: %v", err)
}
d, err := NewDialer(context.Background(),
WithDefaultDialOptions(WithPublicIP()),
WithTokenSource(mock.EmptyTokenSource{}),
Expand Down Expand Up @@ -199,6 +208,9 @@ func TestDialerWithCustomDialFunc(t *testing.T) {
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
)
if err != nil {
t.Fatalf("failed to init SQLAdminService: %v", err)
}
d, err := NewDialer(context.Background(),
WithTokenSource(mock.EmptyTokenSource{}),
WithDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
Expand All @@ -220,3 +232,42 @@ func TestDialerWithCustomDialFunc(t *testing.T) {
t.Fatalf("want = sentinel error, got = %v", err)
}
}

func TestDialerEngineVersion(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tests := []string{
"MYSQL_5_7", "POSTGRES_14", "SQLSERVER_2019_STANDARD", "MYSQL_8_0_18",
}
for _, wantEV := range tests {
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance", mock.WithEngineVersion(wantEV))
svc, cleanup, err := mock.NewSQLAdminService(
ctx,
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
)
if err != nil {
t.Fatalf("failed to init SQLAdminService: %v", err)
}
d, err := NewDialer(context.Background(),
WithTokenSource(mock.EmptyTokenSource{}),
)
if err != nil {
t.Fatalf("failed to init Dialer: %v", err)
}
d.sqladmin = svc
defer func() {
if err := cleanup(); err != nil {
t.Fatalf("%v", err)
}
}()

gotEV, err := d.EngineVersion(ctx, "my-project:my-region:my-instance")
if err != nil {
t.Fatalf("failed to retrieve engine version: %v", err)
}
if wantEV != gotEV {
t.Errorf("InstanceEngineVersion(%s) failed: want %v, got %v", wantEV, gotEV, err)
}
}
}
17 changes: 17 additions & 0 deletions e2e_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"net"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -114,3 +115,19 @@ func TestConnectWithIAMUser(t *testing.T) {
}
t.Log(now)
}

func TestEngineVersion(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
d, err := cloudsqlconn.NewDialer(context.Background())
if err != nil {
t.Fatalf("failed to init Dialer: %v", err)
}
gotEV, err := d.EngineVersion(ctx, postgresConnName)
if err != nil {
t.Fatalf("failed to retrieve engine version: %v", err)
}
if !strings.Contains(gotEV, "POSTGRES") {
t.Errorf("InstanceEngineVersion(%s) failed: want 'POSTGRES', got %v", gotEV, err)
}
}
28 changes: 24 additions & 4 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,7 @@ func (i *Instance) Close() {
// private) and a TLS config that can be used to connect to a Cloud SQL
// instance.
func (i *Instance) ConnectInfo(ctx context.Context, ipType string) (string, *tls.Config, error) {
i.resultGuard.RLock()
res := i.cur
i.resultGuard.RUnlock()
err := res.Wait(ctx)
res, err := i.result(ctx)
if err != nil {
return "", nil, err
}
Expand All @@ -200,6 +197,17 @@ func (i *Instance) ConnectInfo(ctx context.Context, ipType string) (string, *tls
return addr, res.tlsCfg, nil
}

// InstanceEngineVersion returns the engine type and version for the instance. The value
// coresponds to one of the following types for the instance:
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion
func (i *Instance) InstanceEngineVersion(ctx context.Context) (string, error) {
res, err := i.result(ctx)
if err != nil {
return "", err
}
return res.md.version, nil
}

// ForceRefresh triggers an immediate refresh operation to be scheduled and used for future connection attempts.
func (i *Instance) ForceRefresh() {
i.resultGuard.Lock()
Expand All @@ -212,6 +220,18 @@ func (i *Instance) ForceRefresh() {
i.cur = i.next
}

// result returns the most recent refresh result (waiting for it to complete if necessary)
func (i *Instance) result(ctx context.Context) (*refreshResult, error) {
i.resultGuard.RLock()
res := i.cur
i.resultGuard.RUnlock()
err := res.Wait(ctx)
if err != nil {
return nil, err
}
return res, nil
}

// scheduleRefresh schedules a refresh operation to be triggered after a given duration. The returned refreshResult
// can be used to either Cancel or Wait for the operations result.
func (i *Instance) scheduleRefresh(d time.Duration) *refreshResult {
Expand Down
37 changes: 37 additions & 0 deletions internal/cloudsql/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,43 @@ func TestParseConnName(t *testing.T) {
}
}

func TestInstanceEngineVersion(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tests := []string{
"MYSQL_5_7", "POSTGRES_14", "SQLSERVER_2019_STANDARD", "MYSQL_8_0_18",
}
for _, wantEV := range tests {
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance", mock.WithEngineVersion(wantEV))
client, cleanup, err := mock.NewSQLAdminService(
ctx,
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
)
if err != nil {
t.Fatalf("%s", err)
}
defer func() {
if err := cleanup(); err != nil {
t.Fatalf("%v", err)
}
}()
i, err := NewInstance("my-project:my-region:my-instance", client, RSAKey, 30*time.Second, nil)
if err != nil {
t.Fatalf("failed to init instance: %v", err)
}

gotEV, err := i.InstanceEngineVersion(ctx)
if err != nil {
t.Fatalf("failed to retrieve engine version: %v", err)
}
if wantEV != gotEV {
t.Errorf("InstanceEngineVersion(%s) failed: want %v, got %v", wantEV, gotEV, err)
}

}
}

func TestConnectInfo(t *testing.T) {
ctx := context.Background()
wantAddr := "0.0.0.0"
Expand Down
3 changes: 3 additions & 0 deletions internal/cloudsql/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ func TestRefresh(t *testing.T) {
wantExpiry := time.Now().Add(time.Hour).UTC().Round(time.Second)
wantConnName := "my-project:my-region:my-instance"
cn, err := parseConnName(wantConnName)
if err != nil {
t.Fatalf("parseConnName(%s)failed : %v", cn, err)
}
inst := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
mock.WithPublicIP(wantPublicIP),
Expand Down
7 changes: 7 additions & 0 deletions internal/mock/cloudsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ func WithFirstGenBackend() FakeCSQLInstanceOption {
}
}

// WithEngineVersion sets the "DB Version"
func WithEngineVersion(s string) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.dbVersion = s
}
}

// SignFunc is a function that signs the certificate using the provided key. The
// result should be PEM-encoded.
type SignFunc = func(*x509.Certificate, *rsa.PrivateKey) ([]byte, error)
Expand Down

0 comments on commit 6a78bfd

Please sign in to comment.