diff --git a/internal/alloydb/instance.go b/internal/alloydb/instance.go index 7939bc0d..57468669 100644 --- a/internal/alloydb/instance.go +++ b/internal/alloydb/instance.go @@ -27,12 +27,6 @@ import ( "cloud.google.com/go/alloydbconn/internal/alloydbapi" ) -const ( - // refreshBuffer is the amount of time before a result expires to start a - // new refresh attempt. - refreshBuffer = 12 * time.Hour -) - var ( // Instance URI is in the format: // '/projects//locations//clusters//instances/' @@ -220,6 +214,22 @@ func (i *Instance) result(ctx context.Context) (*refreshOperation, error) { return res, nil } +// refreshDuration returns the duration to wait before starting the next +// refresh. Usually that duration will be half of the time until certificate +// expiration. +func refreshDuration(now, certExpiry time.Time) time.Duration { + d := certExpiry.Sub(now) + if d < time.Hour { + // Something is wrong with the certification, refresh now. + if d < 5*time.Minute { + return 0 + } + // Otherwise, wait five minutes before starting the refresh cycle. + return 5 * time.Minute + } + return d / 2 +} + // scheduleRefresh schedules a refresh operation to be triggered after a given // duration. The returned refreshOperation can be used to either Cancel or Wait // for the operations result. @@ -253,8 +263,8 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation { return default: } - nextRefresh := i.cur.result.expiry.Add(-refreshBuffer) - i.next = i.scheduleRefresh(time.Until(nextRefresh)) + t := refreshDuration(time.Now(), i.cur.result.expiry) + i.next = i.scheduleRefresh(t) }) return res } diff --git a/internal/alloydb/instance_test.go b/internal/alloydb/instance_test.go index 9be77c21..4958a6b8 100644 --- a/internal/alloydb/instance_test.go +++ b/internal/alloydb/instance_test.go @@ -214,3 +214,47 @@ func TestClose(t *testing.T) { t.Fatalf("failed to retrieve connect info: %v", err) } } + +func TestRefreshDuration(t *testing.T) { + now := time.Now() + tcs := []struct { + desc string + expiry time.Time + want time.Duration + }{ + { + desc: "when expiration is greater than 1 hour", + expiry: now.Add(4 * time.Hour), + want: 2 * time.Hour, + }, + { + desc: "when expiration is equal to 1 hour", + expiry: now.Add(time.Hour), + want: 30 * time.Minute, + }, + { + desc: "when expiration is less than 1 hour, but greater than 5 minutes", + expiry: now.Add(6 * time.Minute), + want: 5 * time.Minute, + }, + { + desc: "when expiration is less than 5 minutes", + expiry: now.Add(4 * time.Minute), + want: 0, + }, + { + desc: "when expiration is now", + expiry: now, + want: 0, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := refreshDuration(now, tc.expiry) + // round to the second to remove millisecond differences + if got.Round(time.Second) != tc.want { + t.Fatalf("time until refresh: want = %v, got = %v", tc.want, got) + } + }) + } +}