Skip to content

Commit

Permalink
fix: error when dialer is misconfigured with token source (#453)
Browse files Browse the repository at this point in the history
When a dialer is configured with a token source and with Auto IAM AuthN,
the dialer will now return an error and require callers to configure the
WithIAMAuthNTokenSources option instead.

This change will prevent OAuth2 tokens with the sqladmin scope from
being included in the ephemeral certificate.
  • Loading branch information
enocom committed Feb 9, 2023
1 parent 33126e9 commit 7b45a7e
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 17 deletions.
12 changes: 12 additions & 0 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"crypto/rand"
"crypto/rsa"
_ "embed"
"errors"
"fmt"
"net"
"strings"
Expand Down Expand Up @@ -97,6 +98,11 @@ type Dialer struct {
iamTokenSource oauth2.TokenSource
}

var (
errUseTokenSource = errors.New("use WithTokenSource when IAM AuthN is not enabled")
errUseIAMTokenSource = errors.New("use WithIAMAuthNTokenSources instead of WithTokenSource be used when IAM AuthN is enabled")
)

// NewDialer creates a new Dialer.
//
// Initial calls to NewDialer make take longer than normal because generation of an
Expand All @@ -114,6 +120,12 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
return nil, cfg.err
}
}
if cfg.useIAMAuthN && cfg.setTokenSource && cfg.iamLoginTokenSource == nil {
return nil, errUseIAMTokenSource
}
if cfg.setIAMAuthNTokenSource && !cfg.useIAMAuthN {
return nil, errUseTokenSource
}
// Add this to the end to make sure it's not overridden
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))

Expand Down
47 changes: 42 additions & 5 deletions dialer_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// Copyright 2021 Google LLC

//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

//
// https://www.apache.org/licenses/LICENSE-2.0

//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand All @@ -26,6 +26,7 @@ import (

"cloud.google.com/go/cloudsqlconn/errtype"
"cloud.google.com/go/cloudsqlconn/internal/mock"
"golang.org/x/oauth2"
)

func testSuccessfulDial(ctx context.Context, t *testing.T, d *Dialer, i string, opts ...DialOption) {
Expand Down Expand Up @@ -238,7 +239,10 @@ func TestIAMAuthNErrors(t *testing.T) {
defer stop()

d, err := NewDialer(context.Background(),
WithTokenSource(mock.EmptyTokenSource{}), tc.opts)
WithIAMAuthNTokenSources(
mock.EmptyTokenSource{},
mock.EmptyTokenSource{},
), tc.opts)
if err != nil {
t.Fatalf("NewDialer failed with error = %v", err)
}
Expand Down Expand Up @@ -459,7 +463,11 @@ func TestDialDialerOptsConflicts(t *testing.T) {
if err != nil {
t.Fatalf("failed to init SQLAdminService: %v", err)
}
d, err := NewDialer(context.Background(), WithTokenSource(mock.EmptyTokenSource{}), WithOptions(test.dialerOpts...))
d, err := NewDialer(
context.Background(),
WithIAMAuthNTokenSources(mock.EmptyTokenSource{}, mock.EmptyTokenSource{}),
WithOptions(test.dialerOpts...),
)
if err != nil {
t.Fatalf("failed to init Dialer: %v", err)
}
Expand All @@ -478,3 +486,32 @@ func TestDialDialerOptsConflicts(t *testing.T) {
})
}
}

func TestTokenSourceWithIAMAuthN(t *testing.T) {
ts := oauth2.StaticTokenSource(&oauth2.Token{})
tcs := []struct {
desc string
opts []Option
wantErr bool
}{
{
desc: "when token source is set with IAM AuthN",
opts: []Option{WithTokenSource(ts), WithIAMAuthN()},
wantErr: true,
},
{
desc: "when IAM AuthN token source is set without IAM AuthN",
opts: []Option{WithIAMAuthNTokenSources(ts, ts)},
wantErr: true,
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, err := NewDialer(context.Background(), tc.opts...)
gotErr := err != nil
if tc.wantErr != gotErr {
t.Fatalf("err: want = %v, got = %v", tc.wantErr, gotErr)
}
})
}
}
27 changes: 15 additions & 12 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// Copyright 2020 Google LLC

//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

//
// https://www.apache.org/licenses/LICENSE-2.0

//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand Down Expand Up @@ -34,14 +34,16 @@ import (
type Option func(d *dialerConfig)

type dialerConfig struct {
rsaKey *rsa.PrivateKey
sqladminOpts []apiopt.ClientOption
dialOpts []DialOption
dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
refreshTimeout time.Duration
useIAMAuthN bool
iamLoginTokenSource oauth2.TokenSource
useragents []string
rsaKey *rsa.PrivateKey
sqladminOpts []apiopt.ClientOption
dialOpts []DialOption
dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
refreshTimeout time.Duration
useIAMAuthN bool
setTokenSource bool
setIAMAuthNTokenSource bool
iamLoginTokenSource oauth2.TokenSource
useragents []string
// err tracks any dialer options that may have failed.
err error
}
Expand Down Expand Up @@ -114,7 +116,7 @@ func WithDefaultDialOptions(opts ...DialOption) Option {
// WithTokenSource should not be used with WithIAMAuthNTokenSources.
func WithTokenSource(s oauth2.TokenSource) Option {
return func(d *dialerConfig) {
d.iamLoginTokenSource = s
d.setTokenSource = true
d.sqladminOpts = append(d.sqladminOpts, apiopt.WithTokenSource(s))
}
}
Expand All @@ -135,6 +137,7 @@ func WithTokenSource(s oauth2.TokenSource) Option {
// not be used with WithTokenSource.
func WithIAMAuthNTokenSources(apiTS, iamLoginTS oauth2.TokenSource) Option {
return func(d *dialerConfig) {
d.setIAMAuthNTokenSource = true
d.iamLoginTokenSource = iamLoginTS
d.sqladminOpts = append(d.sqladminOpts, apiopt.WithTokenSource(apiTS))
}
Expand Down

0 comments on commit 7b45a7e

Please sign in to comment.