Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for IAM DB Authn #44

Merged
merged 8 commits into from
Aug 31, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .envrc.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ export POSTGRES_USER=some-user
export POSTGRES_PASS=some-password
export POSTGRES_DB=some-db-name
export POSTGRES_CONNECTION_NAME=some-project:some-region:some-instance
export POSTGRES_USER_IAM=some-iam-user
26 changes: 25 additions & 1 deletion dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
"cloud.google.com/go/cloudsqlconn/internal/trace"
"github.com/google/uuid"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)
Expand Down Expand Up @@ -78,6 +80,10 @@ type Dialer struct {
// dialerID uniquely identifies a Dialer. Used for monitoring purposes,
// *only* when a client has configured OpenCensus exporters.
dialerID string

// iamTokenSource supplies the OAuth2 token used for IAM DB Authn. If IAM DB
// Authn is not enabled, iamTokenSource is unused.
enocom marked this conversation as resolved.
Show resolved Hide resolved
iamTokenSource oauth2.TokenSource
}

// NewDialer creates a new Dialer.
Expand All @@ -93,6 +99,23 @@ func NewDialer(ctx context.Context, opts ...DialerOption) (*Dialer, error) {
for _, opt := range opts {
opt(cfg)
}
if cfg.err != nil {
return nil, cfg.err
}
enocom marked this conversation as resolved.
Show resolved Hide resolved
// If callers have not provided a token source, either explicitly with
// WithTokenSource or implicitly with WithCredentialsJSON etc, then use the
// default token source.
if cfg.useIAMAuth && cfg.tokenSource == nil {
ts, err := google.DefaultTokenSource(ctx, sqladmin.SqlserviceAdminScope)
if err != nil {
return nil, fmt.Errorf("failed to create token source: %v", err)
}
cfg.tokenSource = ts
}
// If IAM Authn is not explicitly enabled, remove the token source.
if !cfg.useIAMAuth {
cfg.tokenSource = nil
}

if cfg.rsaKey == nil {
key, err := getDefaultKeys()
Expand Down Expand Up @@ -129,6 +152,7 @@ func NewDialer(ctx context.Context, opts ...DialerOption) (*Dialer, error) {
sqladmin: client,
defaultDialCfg: dialCfg,
dialerID: uuid.New().String(),
iamTokenSource: cfg.tokenSource,
}
return d, nil
}
Expand Down Expand Up @@ -248,7 +272,7 @@ func (d *Dialer) instance(connName string) (*cloudsql.Instance, error) {
if !ok {
// Create a new instance
var err error
i, err = cloudsql.NewInstance(connName, d.sqladmin, d.key, d.refreshTimeout)
i, err = cloudsql.NewInstance(connName, d.sqladmin, d.key, d.refreshTimeout, d.iamTokenSource)
if err != nil {
d.lock.Unlock()
return nil, err
Expand Down
49 changes: 42 additions & 7 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,6 @@ func TestDialerCanConnectToInstance(t *testing.T) {
}
}

func TestDialerInstantiationErrors(t *testing.T) {
_, err := NewDialer(context.Background(), WithCredentialsFile("bogus-file.json"))
if err == nil {
t.Fatalf("expected NewDialer to return error, but got none.")
}
}

func TestDialWithAdminAPIErrors(t *testing.T) {
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
svc, cleanup, err := mock.NewSQLAdminService(context.Background())
Expand Down Expand Up @@ -154,3 +147,45 @@ func TestDialWithConfigurationErrors(t *testing.T) {
t.Fatalf("when TLS handshake fails, want = %T, got = %v", wantErr2, err)
}
}

var fakeServiceAccount = []byte(`{
"type": "service_account",
"project_id": "a-project-id",
"private_key_id": "a-private-key-id",
"private_key": "a-private-key",
"client_email": "email@example.com",
"client_id": "12345",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/email%40example.com"
}`)

func TestIAMAuthn(t *testing.T) {
tcs := []struct {
desc string
opts DialerOption
wantTokenSource bool
}{
{
desc: "When Credentials are provided with IAM Authn ENABLED",
opts: DialerOptions(WithIAMAuthN(), WithCredentialsJSON(fakeServiceAccount)),
wantTokenSource: true,
},
{
desc: "When Credentials are provided with IAM Authn DISABLED",
opts: WithCredentialsJSON(fakeServiceAccount),
wantTokenSource: false,
},
}

for _, tc := range tcs {
d, err := NewDialer(context.Background(), tc.opts)
if err != nil {
t.Errorf("NewDialer failed with error = %v", err)
}
if gotTokenSource := d.iamTokenSource != nil; gotTokenSource != tc.wantTokenSource {
t.Errorf("%v, want = %v, got = %v", tc.desc, tc.wantTokenSource, gotTokenSource)
}
}
}
37 changes: 37 additions & 0 deletions e2e_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ var (
postgresUser = os.Getenv("POSTGRES_USER") // Name of database user.
postgresPass = os.Getenv("POSTGRES_PASS") // Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).
postgresDb = os.Getenv("POSTGRES_DB") // Name of the database to connect to.
postgresUserIAM = os.Getenv("POSTGRES_USER_IAM") // Name of database IAM user.
)

func requirePostgresVars(t *testing.T) {
Expand All @@ -49,6 +50,8 @@ func requirePostgresVars(t *testing.T) {
t.Fatal("'POSTGRES_PASS' env var not set")
case postgresDb:
t.Fatal("'POSTGRES_DB' env var not set")
case postgresUserIAM:
t.Fatal("'POSTGRES_USER_IAM' env var not set")
}
}

Expand Down Expand Up @@ -79,3 +82,37 @@ func TestPgxConnect(t *testing.T) {
}
t.Log(now)
}

func TestConnectWithIAMUser(t *testing.T) {
requirePostgresVars(t)

ctx := context.Background()

// password is intentionally blank
dsn := fmt.Sprintf("user=%s password=\"\" dbname=%s sslmode=disable", postgresUserIAM, postgresDb)
config, err := pgx.ParseConfig(dsn)
if err != nil {
t.Fatalf("failed to parse pgx config: %v", err)
}
d, err := cloudsqlconn.NewDialer(ctx, cloudsqlconn.WithIAMAuthN())
if err != nil {
t.Fatalf("failed to initiate Dialer: %v", err)
}
defer d.Close()
config.DialFunc = func(ctx context.Context, network string, instance string) (net.Conn, error) {
return d.Dial(ctx, postgresConnName)
}

conn, connErr := pgx.ConnectConfig(ctx, config)
if connErr != nil {
t.Fatalf("failed to connect: %s", connErr)
}
defer conn.Close(ctx)

var now time.Time
err = conn.QueryRow(context.Background(), "SELECT NOW()").Scan(&now)
if err != nil {
t.Fatalf("QueryRow failed: %s", err)
}
t.Log(now)
}
13 changes: 6 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ module cloud.google.com/go/cloudsqlconn
go 1.15

require (
cloud.google.com/go v0.75.0 // indirect
github.com/google/uuid v1.3.0
github.com/jackc/pgx/v4 v4.10.1
github.com/pkg/errors v0.9.1 // indirect
go.opencensus.io v0.22.6
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
golang.org/x/oauth2 v0.0.0-20210126194326-f9ce19ea3013
go.opencensus.io v0.23.0
golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420
golang.org/x/oauth2 v0.0.0-20210805134026-6f1e6394065a
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba
google.golang.org/api v0.37.0
google.golang.org/genproto v0.0.0-20210722135532-667f2b7c528f
google.golang.org/grpc v1.39.0
google.golang.org/api v0.54.0
google.golang.org/genproto v0.0.0-20210805201207-89edb61ffb67
google.golang.org/grpc v1.39.1
)
Loading