Skip to content

Commit

Permalink
feat: add support for IAM DB Authn (#44)
Browse files Browse the repository at this point in the history
Fixes #16.
  • Loading branch information
enocom committed Aug 31, 2021
1 parent e52afd7 commit 92e28cf
Show file tree
Hide file tree
Showing 11 changed files with 443 additions and 55 deletions.
1 change: 1 addition & 0 deletions .envrc.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ export POSTGRES_USER=some-user
export POSTGRES_PASS=some-password
export POSTGRES_DB=some-db-name
export POSTGRES_CONNECTION_NAME=some-project:some-region:some-instance
export POSTGRES_USER_IAM=some-iam-user
26 changes: 25 additions & 1 deletion dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
"cloud.google.com/go/cloudsqlconn/internal/trace"
"github.com/google/uuid"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)
Expand Down Expand Up @@ -78,6 +80,10 @@ type Dialer struct {
// dialerID uniquely identifies a Dialer. Used for monitoring purposes,
// *only* when a client has configured OpenCensus exporters.
dialerID string

// iamTokenSource supplies the OAuth2 token used for IAM DB Authn. If IAM DB
// Authn is not enabled, iamTokenSource will be nil.
iamTokenSource oauth2.TokenSource
}

// NewDialer creates a new Dialer.
Expand All @@ -92,6 +98,23 @@ func NewDialer(ctx context.Context, opts ...DialerOption) (*Dialer, error) {
}
for _, opt := range opts {
opt(cfg)
if cfg.err != nil {
return nil, cfg.err
}
}
// If callers have not provided a token source, either explicitly with
// WithTokenSource or implicitly with WithCredentialsJSON etc, then use the
// default token source.
if cfg.useIAMAuthN && cfg.tokenSource == nil {
ts, err := google.DefaultTokenSource(ctx, sqladmin.SqlserviceAdminScope)
if err != nil {
return nil, fmt.Errorf("failed to create token source: %v", err)
}
cfg.tokenSource = ts
}
// If IAM Authn is not explicitly enabled, remove the token source.
if !cfg.useIAMAuthN {
cfg.tokenSource = nil
}

if cfg.rsaKey == nil {
Expand Down Expand Up @@ -129,6 +152,7 @@ func NewDialer(ctx context.Context, opts ...DialerOption) (*Dialer, error) {
sqladmin: client,
defaultDialCfg: dialCfg,
dialerID: uuid.New().String(),
iamTokenSource: cfg.tokenSource,
}
return d, nil
}
Expand Down Expand Up @@ -248,7 +272,7 @@ func (d *Dialer) instance(connName string) (*cloudsql.Instance, error) {
if !ok {
// Create a new instance
var err error
i, err = cloudsql.NewInstance(connName, d.sqladmin, d.key, d.refreshTimeout)
i, err = cloudsql.NewInstance(connName, d.sqladmin, d.key, d.refreshTimeout, d.iamTokenSource)
if err != nil {
d.lock.Unlock()
return nil, err
Expand Down
49 changes: 42 additions & 7 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,6 @@ func TestDialerCanConnectToInstance(t *testing.T) {
}
}

func TestDialerInstantiationErrors(t *testing.T) {
_, err := NewDialer(context.Background(), WithCredentialsFile("bogus-file.json"))
if err == nil {
t.Fatalf("expected NewDialer to return error, but got none.")
}
}

func TestDialWithAdminAPIErrors(t *testing.T) {
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
svc, cleanup, err := mock.NewSQLAdminService(context.Background())
Expand Down Expand Up @@ -154,3 +147,45 @@ func TestDialWithConfigurationErrors(t *testing.T) {
t.Fatalf("when TLS handshake fails, want = %T, got = %v", wantErr2, err)
}
}

var fakeServiceAccount = []byte(`{
"type": "service_account",
"project_id": "a-project-id",
"private_key_id": "a-private-key-id",
"private_key": "a-private-key",
"client_email": "email@example.com",
"client_id": "12345",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/email%40example.com"
}`)

func TestIAMAuthn(t *testing.T) {
tcs := []struct {
desc string
opts DialerOption
wantTokenSource bool
}{
{
desc: "When Credentials are provided with IAM Authn ENABLED",
opts: DialerOptions(WithIAMAuthN(), WithCredentialsJSON(fakeServiceAccount)),
wantTokenSource: true,
},
{
desc: "When Credentials are provided with IAM Authn DISABLED",
opts: WithCredentialsJSON(fakeServiceAccount),
wantTokenSource: false,
},
}

for _, tc := range tcs {
d, err := NewDialer(context.Background(), tc.opts)
if err != nil {
t.Errorf("NewDialer failed with error = %v", err)
}
if gotTokenSource := d.iamTokenSource != nil; gotTokenSource != tc.wantTokenSource {
t.Errorf("%v, want = %v, got = %v", tc.desc, tc.wantTokenSource, gotTokenSource)
}
}
}
37 changes: 37 additions & 0 deletions e2e_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ var (
postgresUser = os.Getenv("POSTGRES_USER") // Name of database user.
postgresPass = os.Getenv("POSTGRES_PASS") // Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).
postgresDb = os.Getenv("POSTGRES_DB") // Name of the database to connect to.
postgresUserIAM = os.Getenv("POSTGRES_USER_IAM") // Name of database IAM user.
)

func requirePostgresVars(t *testing.T) {
Expand All @@ -49,6 +50,8 @@ func requirePostgresVars(t *testing.T) {
t.Fatal("'POSTGRES_PASS' env var not set")
case postgresDb:
t.Fatal("'POSTGRES_DB' env var not set")
case postgresUserIAM:
t.Fatal("'POSTGRES_USER_IAM' env var not set")
}
}

Expand Down Expand Up @@ -79,3 +82,37 @@ func TestPgxConnect(t *testing.T) {
}
t.Log(now)
}

func TestConnectWithIAMUser(t *testing.T) {
requirePostgresVars(t)

ctx := context.Background()

// password is intentionally blank
dsn := fmt.Sprintf("user=%s password=\"\" dbname=%s sslmode=disable", postgresUserIAM, postgresDb)
config, err := pgx.ParseConfig(dsn)
if err != nil {
t.Fatalf("failed to parse pgx config: %v", err)
}
d, err := cloudsqlconn.NewDialer(ctx, cloudsqlconn.WithIAMAuthN())
if err != nil {
t.Fatalf("failed to initiate Dialer: %v", err)
}
defer d.Close()
config.DialFunc = func(ctx context.Context, network string, instance string) (net.Conn, error) {
return d.Dial(ctx, postgresConnName)
}

conn, connErr := pgx.ConnectConfig(ctx, config)
if connErr != nil {
t.Fatalf("failed to connect: %s", connErr)
}
defer conn.Close(ctx)

var now time.Time
err = conn.QueryRow(context.Background(), "SELECT NOW()").Scan(&now)
if err != nil {
t.Fatalf("QueryRow failed: %s", err)
}
t.Log(now)
}
13 changes: 6 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ module cloud.google.com/go/cloudsqlconn
go 1.15

require (
cloud.google.com/go v0.75.0 // indirect
github.com/google/uuid v1.3.0
github.com/jackc/pgx/v4 v4.10.1
github.com/pkg/errors v0.9.1 // indirect
go.opencensus.io v0.22.6
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
golang.org/x/oauth2 v0.0.0-20210126194326-f9ce19ea3013
go.opencensus.io v0.23.0
golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420
golang.org/x/oauth2 v0.0.0-20210805134026-6f1e6394065a
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba
google.golang.org/api v0.37.0
google.golang.org/genproto v0.0.0-20210722135532-667f2b7c528f
google.golang.org/grpc v1.39.0
google.golang.org/api v0.54.0
google.golang.org/genproto v0.0.0-20210805201207-89edb61ffb67
google.golang.org/grpc v1.39.1
)
Loading

0 comments on commit 92e28cf

Please sign in to comment.