Skip to content

Commit

Permalink
fix: remove duplicate refresh operations (#806)
Browse files Browse the repository at this point in the history
When callers (like the Cloud SQL Proxy) warm up the background refresh
with EngineVersion, the initial refresh operation starts with IAM AuthN
disabled. Then when a caller connects with IAM authentication, the
existing refresh is invalidated (because it doesn't include the client's
OAuth2 token). In effect, two refresh operations are completed when the
dialer could have just run one initially.

This commit ensures calls to EngineVersion respect the dialer's global
IAM authentication setting. If IAM authentication is enabled at the
dialer level, then EngineVersion will ensure the refresh operation uses
IAM authentication. As a result, only one refresh operation occurs
between warmup and first connection.

In cases where a dialer is initialized without IAM authentication, but
then a call to dial requests IAM authentication, a second refresh is
unavoidable. This seems to be an uncommon enough use case that it is an
acceptable tradeoff given how IAM authentication is tightly coupled with
the client certificate refresh.

Separately, when the Cloud SQL Proxy is started with the --token flag in
combination with IAM authentication, the underlying token does not have
a corresponding refresh token and so cannot be refreshed. As a result,
when the duplicate refresh occurs (as described above), a third refresh
attempt is started because the token is missing its expiration field
(there is only an AccessToken, with no RefreshToken or Expiry).

The dialer sees the missing expiration as an expired client certificate
and immediately starts a new refresh. But because the dialer has already
consumed two attempts, the rate limiter (2 initial attempts, then
30s/attempt) forces the client to wait 30s before connecting.

This commit ensures that situation will not happen by using the correct
client certificate expiration and not the invalid token expiration. For
cases where the Cloud SQL Proxy configures the dialer with the --token
flag (and no refresh token), the dialer will always default to using the
client certificate's expiration. This means the refresh cycle will fail
once the token expires.

Fixes #771
  • Loading branch information
enocom authored May 21, 2024
1 parent 1ee46e8 commit beb3605
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 11 deletions.
2 changes: 1 addition & 1 deletion dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
if err != nil {
return "", err
}
i := d.connectionInfoCache(ctx, cn, nil)
i := d.connectionInfoCache(ctx, cn, &d.defaultDialConfig.useIAMAuthN)
ci, err := i.ConnectionInfo(ctx)
if err != nil {
return "", err
Expand Down
49 changes: 44 additions & 5 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,38 @@ func TestDialerEngineVersion(t *testing.T) {
}
}

// TestEngineVersionAvoidsDuplicateRefreshWithIAMAuthN checks that, when Auto
// IAM AuthN is enabled on the dialer, EngineVersion warms the cache with a
// client certificate that already has Auto IAM AuthN enabled, so the dial that
// follows is served by the same set of API requests.
func TestEngineVersionAvoidsDuplicateRefreshWithIAMAuthN(t *testing.T) {
	ctx := context.Background()
	instance := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")

	// Enable IAM AuthN at the dialer level with placeholder token sources.
	opts := []Option{
		WithIAMAuthN(),
		WithIAMAuthNTokenSources(mock.EmptyTokenSource{}, mock.EmptyTokenSource{}),
	}
	// There should only be two API requests: one instance lookup and one
	// ephemeral certificate request.
	wantReqs := []*mock.Request{
		mock.InstanceGetSuccess(instance, 1),
		mock.CreateEphemeralSuccess(instance, 1),
	}
	dialer := setupDialer(t, setupConfig{
		testInstance:  instance,
		dialerOptions: opts,
		reqs:          wantReqs,
	})

	// Warm up the cache first, then dial.
	if _, err := dialer.EngineVersion(ctx, instance.String()); err != nil {
		t.Fatal(err)
	}
	testSuccessfulDial(ctx, t, dialer, instance.String())
}

func TestDialerUserAgent(t *testing.T) {
data, err := os.ReadFile("version.txt")
if err != nil {
Expand All @@ -420,7 +452,7 @@ func TestWarmup(t *testing.T) {
expectedCalls []*mock.Request
}{
{
desc: "warmup and dial are the same",
desc: "Warmup and Dial both use IAM AuthN",
warmupOpts: []DialOption{WithDialIAMAuthN(true)},
dialOpts: []DialOption{WithDialIAMAuthN(true)},
expectedCalls: []*mock.Request{
Expand All @@ -429,7 +461,7 @@ func TestWarmup(t *testing.T) {
},
},
{
desc: "warmup and dial are different",
desc: "Warmup uses IAM Authn, Dial does not",
warmupOpts: []DialOption{WithDialIAMAuthN(true)},
dialOpts: []DialOption{WithDialIAMAuthN(false)},
expectedCalls: []*mock.Request{
Expand All @@ -438,19 +470,26 @@ func TestWarmup(t *testing.T) {
},
},
{
desc: "warmup and default dial are different",
desc: "Warmup uses IAM AuthN, Dial uses global setting",
warmupOpts: []DialOption{WithDialIAMAuthN(true)},
dialOpts: []DialOption{},
expectedCalls: []*mock.Request{
mock.InstanceGetSuccess(inst, 2),
mock.CreateEphemeralSuccess(inst, 2),
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
},
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
d := setupDialer(t, setupConfig{
dialerOptions: []Option{
WithIAMAuthN(),
WithIAMAuthNTokenSources(
mock.EmptyTokenSource{},
mock.EmptyTokenSource{},
),
},
testInstance: inst,
reqs: test.expectedCalls,
})
Expand Down
7 changes: 6 additions & 1 deletion internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,13 @@ func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation {
nil,
)
} else {
var useIAMAuthN bool
i.mu.Lock()
useIAMAuthN = i.useIAMAuthNDial
i.mu.Unlock()
r.result, r.err = i.r.ConnectionInfo(
ctx, i.connName, i.key, i.useIAMAuthNDial)
ctx, i.connName, i.key, useIAMAuthN,
)
}
switch r.err {
case nil:
Expand Down
19 changes: 15 additions & 4 deletions internal/cloudsql/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,23 @@ func fetchMetadata(
return m, nil
}

var expired = time.Time{}.Add(1)

// canRefresh reports whether the token's expiration is meaningful, i.e.,
// whether the token came from a source that can actually refresh it.
//
// The check compares expirations at whole-second resolution against the
// package-level expired sentinel. Because that sentinel sits one nanosecond
// after the zero time, an unset (zero) Expiry falls in the same second and is
// rejected as well. That is the case for a token supplied without a refresh
// token or expiry (as with the Cloud SQL Proxy's --token flag).
func canRefresh(t *oauth2.Token) bool {
	tokenSec := t.Expiry.Unix()
	sentinelSec := expired.Unix()
	return tokenSec != sentinelSec
}

// refreshToken will retrieve a new token, only if a refresh token is present.
//
// It copies tok's access token, token type, and refresh token into a new
// token whose Expiry is already in the past, then asks
// oauth2.ReuseTokenSource (seeded with that token and backed by ts) for a
// token. NOTE(review): presumably the expired seed keeps ReuseTokenSource from
// handing back the cached token, so ts is always consulted — confirm against
// the oauth2 package documentation.
func refreshToken(ts oauth2.TokenSource, tok *oauth2.Token) (*oauth2.Token, error) {
expiredToken := &oauth2.Token{
AccessToken: tok.AccessToken,
TokenType: tok.TokenType,
RefreshToken: tok.RefreshToken,
Expiry: time.Time{}.Add(1), // Expired
Expiry: expired,
}
return oauth2.ReuseTokenSource(expiredToken, ts).Token()
}
Expand Down Expand Up @@ -217,9 +228,9 @@ func fetchEphemeralCert(
)
}
if ts != nil {
// Adjust the certificate's expiration to be the earliest of the token's
// expiration or the certificate's expiration.
if tok.Expiry.Before(clientCert.NotAfter) {
// Adjust the certificate's expiration to be the earliest of
// the token's expiration or the certificate's expiration.
if canRefresh(tok) && tok.Expiry.Before(clientCert.NotAfter) {
clientCert.NotAfter = tok.Expiry
}
}
Expand Down
35 changes: 35 additions & 0 deletions internal/cloudsql/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,41 @@ func TestRefresh(t *testing.T) {
t.Fatalf("expiry mismatch, want = %v, got = %v", wantExpiry, rr.Expiration)
}
}

// TestRefreshWithStaticTokenSource covers callers that provide a static token
// source that cannot be refreshed (e.g., when the Cloud SQL Proxy is invoked
// with --token). In that case the refresher cannot determine the token's
// expiration (without additional API calls), so it should use the
// certificate's expiration instead of the token's expiration, which is
// otherwise unset.
func TestRefreshWithStaticTokenSource(t *testing.T) {
	ctx := context.Background()
	connName := testInstanceConnName()
	fake := mock.NewFakeCSQLInstance(
		connName.Project(), connName.Region(), connName.Name(),
	)

	adminClient, stop, err := mock.NewSQLAdminService(
		ctx,
		mock.InstanceGetSuccess(fake, 1),
		mock.CreateEphemeralSuccess(fake, 1),
	)
	if err != nil {
		t.Fatalf("failed to create test SQL admin service: %s", err)
	}
	t.Cleanup(func() { _ = stop() })

	// The token carries only an access token: no refresh token, no expiry.
	staticTok := &oauth2.Token{AccessToken: "myaccestoken"}
	tokenSource := oauth2.StaticTokenSource(staticTok)
	refresher := newRefresher(nullLogger{}, adminClient, tokenSource, testDialerID)

	info, err := refresher.ConnectionInfo(ctx, connName, RSAKey, true)
	if err != nil {
		t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err)
	}
	// The expiration must be strictly in the future.
	if now := time.Now(); !info.Expiration.After(now) {
		t.Fatalf(
			"Connection info expiration should be in the future, got = %v",
			info.Expiration,
		)
	}
}

func TestRefreshRetries50xResponses(t *testing.T) {
cn := testInstanceConnName()
inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name(),
Expand Down

0 comments on commit beb3605

Please sign in to comment.