Skip to content

Commit

Permalink
fix: schedule refreshes based on result expiration instead of fixed i…
Browse files Browse the repository at this point in the history
…nterval (#21)
  • Loading branch information
kurtisvg committed May 12, 2021
1 parent eb06ae2 commit 65073d0
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 19 deletions.
69 changes: 51 additions & 18 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ import (
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)

// refreshBuffer is how far ahead of a result's expiry the next refresh
// attempt is started.
const refreshBuffer = 5 * time.Minute

var (
// Instance connection name is the format <PROJECT>:<REGION>:<INSTANCE>
// Additionally, we have to support legacy "domain-scoped" projects (e.g. "google.com:PROJECT")
Expand Down Expand Up @@ -67,6 +72,7 @@ func parseConnName(cn string) (connName, error) {
type refreshResult struct {
md metadata
tlsCfg *tls.Config
expiry time.Time
err error

// timer that triggers refresh, can be used to cancel.
Expand All @@ -77,17 +83,31 @@ type refreshResult struct {

// Cancel prevents the refresh from starting, if it hasn't already started. Returns true if the timer
// was stopped successfully, or false if the refresh has already started.
func (i *refreshResult) Cancel() bool {
return i.timer.Stop()
func (r *refreshResult) Cancel() bool {
return r.timer.Stop()
}

// Wait blocks until the refreshResult attempt is completed.
func (i *refreshResult) Wait(ctx context.Context) error {
func (r *refreshResult) Wait(ctx context.Context) error {
select {
case <-r.ready:
return r.err
case <-ctx.Done():
return ctx.Err()
case <-i.ready:
return i.err
}
}

// IsValid reports whether this result has finished, finished without an
// error, and has not yet passed its expiry time.
func (r *refreshResult) IsValid() bool {
	// Non-blocking check of the ready channel: if it is not yet readable the
	// refresh is still pending and the result cannot be used.
	select {
	case <-r.ready:
	default:
		return false
	}
	// The refresh has completed; it is only usable if it succeeded and the
	// current time is not after its expiry.
	return r.err == nil && !time.Now().After(r.expiry)
}

Expand Down Expand Up @@ -166,32 +186,39 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshResult {
res.ready = make(chan struct{})
res.timer = time.AfterFunc(d, func() {
ctx, cancel := context.WithTimeout(context.Background(), i.refreshTimeout)
res.md, res.tlsCfg, res.err = performRefresh(ctx, i.client, i.clientLimiter, i.connName, i.key)
res.md, res.tlsCfg, res.expiry, res.err = performRefresh(ctx, i.client, i.clientLimiter, i.connName, i.key)
cancel()

close(res.ready)
// Once the refresh is complete, update "current" with working result and schedule a new refresh
i.resultGuard.Lock()
defer i.resultGuard.Unlock()
// TODO: only replace cur result if it's not valid
i.cur = res
// if failed, scheduled the next refresh immediately
if res.err != nil {
// TODO: add a backoff on retries
// if failed, schedule the next refresh immediately
i.next = i.scheduleRefresh(0)
// If the latest result is bad, avoid replacing the used result while it's
// still valid and potentially able to provide successful connections.
// TODO: This means that errors while the current result is still valid are
// suppressed. We should try to surface errors in a more meaningful way.
if !i.cur.IsValid() {
i.cur = res
}
return
}
i.next = i.scheduleRefresh(55 * time.Minute)
// Update the current results, and schedule the next refresh in the future
i.cur = res
nextRefresh := i.cur.expiry.Add(-refreshBuffer)
i.next = i.scheduleRefresh(time.Until(nextRefresh))
})
return res
}

// performRefresh immediately performs a full refresh operation using the Cloud SQL Admin API.
func performRefresh(ctx context.Context, client *sqladmin.Service, l *rate.Limiter, cn connName, k *rsa.PrivateKey) (metadata, *tls.Config, error) {
func performRefresh(ctx context.Context, client *sqladmin.Service, l *rate.Limiter, cn connName, k *rsa.PrivateKey) (metadata, *tls.Config, time.Time, error) {
// avoid refreshing too often to try not to tax the SQL Admin API quotas
err := l.Wait(ctx)
if err != nil {
return metadata{}, nil, fmt.Errorf("refresh was throttled until context expired: %v", err)
return metadata{}, nil, time.Time{}, fmt.Errorf("refresh was throttled until context expired: %w", err)
}

// start async fetching the instance's metadata
Expand Down Expand Up @@ -223,22 +250,28 @@ func performRefresh(ctx context.Context, client *sqladmin.Service, l *rate.Limit
select {
case r := <-mdC:
if r.err != nil {
return md, nil, fmt.Errorf("fetch metadata failed: %w", r.err)
return md, nil, time.Time{}, fmt.Errorf("fetch metadata failed: %w", r.err)
}
md = r.md
case <-ctx.Done():
return md, nil, fmt.Errorf("refresh failed: %w", ctx.Err())
return md, nil, time.Time{}, fmt.Errorf("refresh failed: %w", ctx.Err())
}
var ec tls.Certificate
select {
case r := <-ecC:
if r.err != nil {
return md, nil, fmt.Errorf("fetch ephemeral cert failed: %w", r.err)
return md, nil, time.Time{}, fmt.Errorf("fetch ephemeral cert failed: %w", r.err)
}
ec = r.ec
case <-ctx.Done():
return md, nil, fmt.Errorf("refresh failed: %w", ctx.Err())
return md, nil, time.Time{}, fmt.Errorf("refresh failed: %w", ctx.Err())
}

return md, createTLSConfig(cn, md, ec), nil
c := createTLSConfig(cn, md, ec)
// This should never not be the case, but we check to avoid a potential nil-pointer
expiry := time.Time{}
if len(c.Certificates) > 0 {
expiry = c.Certificates[0].Leaf.NotAfter
}
return md, c, expiry, nil
}
2 changes: 1 addition & 1 deletion internal/cloudsql/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,6 @@ func TestRefreshTimeout(t *testing.T) {

_, _, err = im.ConnectInfo(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("failed to retrieve connect info: %v", err)
t.Fatalf("connect info did not context timeout: %v", err)
}
}

0 comments on commit 65073d0

Please sign in to comment.