Skip to content

Commit

Permalink
fix: support MySQL driver’s conn check. (#226)
Browse files Browse the repository at this point in the history
Fixes #225.
  • Loading branch information
enocom committed Jun 16, 2022
1 parent 6490a5b commit 4b48e3b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 2 deletions.
19 changes: 17 additions & 2 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ import (
"crypto/rsa"
"crypto/tls"
_ "embed"
"errors"
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

"cloud.google.com/go/cloudsqlconn/errtype"
Expand Down Expand Up @@ -230,7 +232,7 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
trace.RecordDialLatency(ctx, instance, d.dialerID, latency)
}()

return newInstrumentedConn(tlsConn, func() {
return newInstrumentedConn(conn, tlsConn, func() {
n := atomic.AddUint64(&i.OpenConns, ^uint64(0))
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, i.String())
}), nil
Expand Down Expand Up @@ -264,9 +266,10 @@ func (d *Dialer) Warmup(ctx context.Context, instance string, opts ...DialOption

// newInstrumentedConn initializes an instrumentedConn that on closing will
// decrement the number of open connects and record the result.
func newInstrumentedConn(conn net.Conn, closeFunc func()) *instrumentedConn {
func newInstrumentedConn(rawConn, conn net.Conn, closeFunc func()) *instrumentedConn {
return &instrumentedConn{
Conn: conn,
rawConn: rawConn,
closeFunc: closeFunc,
}
}
Expand All @@ -275,9 +278,21 @@ func newInstrumentedConn(conn net.Conn, closeFunc func()) *instrumentedConn {
// is closed.
type instrumentedConn struct {
net.Conn
// rawConn is the underlying net.Conn without TLS
rawConn net.Conn
closeFunc func()
}

// SyscallConn supports a connection check in the MySQL driver by delegating to
// the underlying non-TLS net.Conn.
func (i *instrumentedConn) SyscallConn() (syscall.RawConn, error) {
sconn, ok := i.rawConn.(syscall.Conn)
if !ok {
return nil, errors.New("connection is not a syscall.Conn")
}
return sconn.SyscallConn()
}

// Close delegates to the underylying net.Conn interface and reports the close
// to the provided closeFunc only when Close returns no error.
func (i *instrumentedConn) Close() error {
Expand Down
43 changes: 43 additions & 0 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"os"
"runtime"
"strings"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -75,6 +76,48 @@ func TestDialerCanConnectToInstance(t *testing.T) {
testSuccessfulDial(t, d, context.Background(), "my-project:my-region:my-instance", WithPublicIP())
}

func TestDialerConnectionSupportsSyscalls(t *testing.T) {
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
svc, cleanup, err := mock.NewSQLAdminService(
context.Background(),
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
)
if err != nil {
t.Fatalf("failed to init SQLAdminService: %v", err)
}
stop := mock.StartServerProxy(t, inst)
defer func() {
stop()
if err := cleanup(); err != nil {
t.Fatalf("%v", err)
}
}()

d, err := NewDialer(context.Background(),
WithDefaultDialOptions(WithPublicIP()),
WithTokenSource(mock.EmptyTokenSource{}),
)
if err != nil {
t.Fatalf("expected NewDialer to succeed, but got error: %v", err)
}
d.sqladmin = svc

conn, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
if err != nil {
t.Fatalf("expected Dial to succeed, but got error: %v", err)
}
defer conn.Close()
sconn, ok := conn.(syscall.Conn)
if !ok {
t.Fatalf("expected conn to be a syscall.Conn, but it was not")
}
_, err = sconn.SyscallConn()
if err != nil {
t.Fatalf("expected syscall.RawConn, got error: %v", err)
}
}

func TestDialWithAdminAPIErrors(t *testing.T) {
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
svc, cleanup, err := mock.NewSQLAdminService(context.Background())
Expand Down

0 comments on commit 4b48e3b

Please sign in to comment.