diff --git a/internal/adminapi/backoff_strategy_konnect.go b/internal/adminapi/backoff_strategy_konnect.go index 5b5e81e6d9..f8812d3c00 100644 --- a/internal/adminapi/backoff_strategy_konnect.go +++ b/internal/adminapi/backoff_strategy_konnect.go @@ -26,10 +26,6 @@ type Clock interface { Now() time.Time } -type SystemClock struct{} - -func (SystemClock) Now() time.Time { return time.Now() } - // KonnectBackoffStrategy keeps track of Konnect config push backoffs. // // It takes into account: diff --git a/internal/adminapi/client.go b/internal/adminapi/client.go index bbead3755f..005ca8acf6 100644 --- a/internal/adminapi/client.go +++ b/internal/adminapi/client.go @@ -11,6 +11,7 @@ import ( k8stypes "k8s.io/apimachinery/pkg/types" "github.com/kong/kubernetes-ingress-controller/v2/internal/util" + "github.com/kong/kubernetes-ingress-controller/v2/internal/util/clock" ) // Client is a wrapper around raw *kong.Client. It's advised to pass this wrapper across the codebase, and @@ -61,7 +62,7 @@ func NewKonnectClient(c *kong.Client, runtimeGroup string) *KonnectClient { konnectRuntimeGroup: runtimeGroup, pluginSchemaStore: util.NewPluginSchemaStore(c), }, - backoffStrategy: NewKonnectBackoffStrategy(SystemClock{}), + backoffStrategy: NewKonnectBackoffStrategy(clock.System{}), } } diff --git a/internal/license/agent.go b/internal/license/agent.go index daf420bd04..dcfbf0eb29 100644 --- a/internal/license/agent.go +++ b/internal/license/agent.go @@ -11,6 +11,7 @@ import ( "github.com/samber/mo" "github.com/kong/kubernetes-ingress-controller/v2/internal/util" + "github.com/kong/kubernetes-ingress-controller/v2/internal/util/clock" ) const ( @@ -52,6 +53,20 @@ func WithPollingPeriod(regularPollingPeriod time.Duration) AgentOpt { } } +type Ticker interface { + Stop() + Channel() <-chan time.Time + Reset(d time.Duration) +} + +// WithTicker sets the ticker in Agent. This is useful for testing. +// Ticker doesn't define the period, it defines the implementation of ticking. +func WithTicker(t Ticker) AgentOpt { + return func(a *Agent) { + a.ticker = t + } +} + // NewAgent creates a new license agent that retrieves a license from the given url once every given period. func NewAgent( konnectLicenseClient KonnectLicenseClient, @@ -63,6 +78,9 @@ func NewAgent( konnectLicenseClient: konnectLicenseClient, initialPollingPeriod: DefaultInitialPollingPeriod, regularPollingPeriod: DefaultPollingPeriod, + // Note: the ticker defines the implementation of ticking, not the period. + ticker: clock.NewTicker(), + startedCh: make(chan struct{}), } for _, opt := range opts { @@ -78,6 +96,8 @@ type Agent struct { konnectLicenseClient KonnectLicenseClient initialPollingPeriod time.Duration regularPollingPeriod time.Duration + ticker Ticker + startedCh chan struct{} // cachedLicense is the current license retrieved from upstream. It's optional because we may not have retrieved a // license yet. @@ -122,21 +142,28 @@ func (a *Agent) GetLicense() mo.Option[kong.License] { return mo.None[kong.License]() } +// Started returns a channel which will be closed when the Agent has started. +func (a *Agent) Started() <-chan struct{} { + return a.startedCh +} + // runPollingLoop updates the license on a regular period until the context is cancelled. // It will run at a faster period initially, and then switch to the regular period once a license is retrieved. func (a *Agent) runPollingLoop(ctx context.Context) error { - ticker := time.NewTicker(a.resolvePollingPeriod()) - defer ticker.Stop() + a.ticker.Reset(a.initialPollingPeriod) + defer a.ticker.Stop() + ch := a.ticker.Channel() + close(a.startedCh) for { select { - case <-ticker.C: + case <-ch: a.logger.V(util.DebugLevel).Info("retrieving license from external service") if err := a.reconcileLicenseWithKonnect(ctx); err != nil { a.logger.Error(err, "could not reconcile license with Konnect") } // Reset the ticker to run with the expected period which may change after we receive the license. - ticker.Reset(a.resolvePollingPeriod()) + a.ticker.Reset(a.resolvePollingPeriod()) case <-ctx.Done(): a.logger.Info("context done, shutting down license agent") return ctx.Err() diff --git a/internal/license/agent_test.go b/internal/license/agent_test.go index b41135abbd..494f731994 100644 --- a/internal/license/agent_test.go +++ b/internal/license/agent_test.go @@ -12,6 +12,8 @@ import ( "github.com/stretchr/testify/require" "github.com/kong/kubernetes-ingress-controller/v2/internal/license" + "github.com/kong/kubernetes-ingress-controller/v2/internal/util/clock" + "github.com/kong/kubernetes-ingress-controller/v2/test/mocks" ) type mockKonnectClientClient struct { @@ -19,17 +21,25 @@ type mockKonnectClientClient struct { err error getCalls []time.Time lock sync.RWMutex + clock Clock } -func newMockKonnectLicenseClient(license mo.Option[license.KonnectLicense]) *mockKonnectClientClient { - return &mockKonnectClientClient{konnectLicense: license} +type Clock interface { + Now() time.Time +} + +func newMockKonnectLicenseClient(license mo.Option[license.KonnectLicense], clock Clock) *mockKonnectClientClient { + return &mockKonnectClientClient{ + konnectLicense: license, + clock: clock, + } } func (m *mockKonnectClientClient) Get(context.Context) (mo.Option[license.KonnectLicense], error) { m.lock.Lock() defer m.lock.Unlock() - m.getCalls = append(m.getCalls, time.Now()) + m.getCalls = append(m.getCalls, m.clock.Now()) if m.err != nil { return mo.None[license.KonnectLicense](), m.err @@ -83,44 +93,47 @@ func TestAgent(t *testing.T) { } matchTime = time.Now() return true - }, time.Second, time.Nanosecond) + }, time.Second, time.Millisecond) return matchTime } t.Run("initial license is retrieved", func(t *testing.T) { - upstreamClient := newMockKonnectLicenseClient(mo.Some(expectedLicense)) + upstreamClient := newMockKonnectLicenseClient(mo.Some(expectedLicense), clock.System{}) a := license.NewAgent(upstreamClient, logr.Discard()) - go func() { - err := a.Start(ctx) - require.NoError(t, err) - }() + go a.Start(ctx) //nolint:errcheck expectLicenseToMatchEventually(t, a, expectedLicense.Payload) }) t.Run("initial license retrieval fails and recovers", func(t *testing.T) { - upstreamClient := newMockKonnectLicenseClient(mo.None[license.KonnectLicense]()) + ticker := mocks.NewTicker() + + upstreamClient := newMockKonnectLicenseClient(mo.None[license.KonnectLicense](), ticker) // Return an error on the first call to List() to verify that the agent handles this correctly. upstreamClient.ReturnError(errors.New("something went wrong on a backend")) const ( - // Set the initial polling period to a very short duration to ensure that the agent retries quickly. - initialPollingPeriod = time.Millisecond - regularPollingPeriod = time.Millisecond * 5 - allowedDelta = time.Millisecond + initialPollingPeriod = time.Minute * 3 + regularPollingPeriod = time.Minute * 20 + allowedDelta = time.Second ) + a := license.NewAgent( upstreamClient, logr.Discard(), license.WithInitialPollingPeriod(initialPollingPeriod), license.WithPollingPeriod(regularPollingPeriod), + license.WithTicker(ticker), ) startTime := time.Now() - go func() { - err := a.Start(ctx) - require.NoError(t, err) - }() + go a.Start(ctx) //nolint:errcheck + + select { + case <-a.Started(): + case <-time.After(time.Second): + require.FailNow(t, "timed out waiting for agent to start") + } t.Run("initial polling period is used when no license is retrieved", func(t *testing.T) { require.Eventually(t, func() bool { @@ -128,8 +141,12 @@ func TestAgent(t *testing.T) { }, time.Second, time.Nanosecond, "expected upstream client to be called at least once") firstListCallTime := upstreamClient.GetCalls()[0] - require.WithinDuration(t, startTime.Add(initialPollingPeriod), firstListCallTime, allowedDelta, - "expected first call to List() to happen after the initial polling period") + + require.WithinDuration(t, startTime, firstListCallTime, allowedDelta, + "expected first call to List() to happen immediately after starting the agent") + + // Initial polling period has passed... + ticker.Add(initialPollingPeriod) require.Eventually(t, func() bool { return len(upstreamClient.GetCalls()) >= 2 @@ -145,17 +162,27 @@ func TestAgent(t *testing.T) { t.Run("regular polling period is used after the initial license is retrieved", func(t *testing.T) { // Now return a valid response to ensure that the agent recovers. upstreamClient.ReturnSuccess(mo.Some(expectedLicense)) + + // Regular polling period has passed... + ticker.Add(regularPollingPeriod) + expectLicenseToMatchEventually(t, a, expectedLicense.Payload) listCallsAfterMatchCount := len(upstreamClient.GetCalls()) + + // Regular polling period has passed... + ticker.Add(regularPollingPeriod) + require.Eventually(t, func() bool { return len(upstreamClient.GetCalls()) > listCallsAfterMatchCount - }, time.Second, time.Nanosecond, "expected upstream client to be called at least once after the license is retrieved") + }, time.Second, time.Millisecond, "expected upstream client to be called at least once after the license is retrieved") - listCalls := upstreamClient.GetCalls() - lastListCall := listCalls[len(listCalls)-1] - lastButOneCall := listCalls[len(listCalls)-2] - require.WithinDuration(t, lastButOneCall.Add(regularPollingPeriod), lastListCall, allowedDelta) + require.Eventually(t, func() bool { + listCalls := upstreamClient.GetCalls() + lastListCall := listCalls[len(listCalls)-1] + lastButOneCall := listCalls[len(listCalls)-2] + return lastListCall.Sub(lastButOneCall).Abs() <= allowedDelta + }, time.Second, time.Millisecond) }) t.Run("after the license is retrieved, errors returned from upstream do not override the license", func(t *testing.T) { @@ -163,6 +190,10 @@ func TestAgent(t *testing.T) { // Wait for the call to happen. initialListCalls := len(upstreamClient.GetCalls()) + + // Regular polling period has passed... + ticker.Add(regularPollingPeriod) + require.Eventually(t, func() bool { return len(upstreamClient.GetCalls()) > initialListCalls }, time.Second, time.Nanosecond) @@ -179,6 +210,9 @@ func TestAgent(t *testing.T) { // Wait for the call to happen. initialListCalls := len(upstreamClient.GetCalls()) + // Regular polling period has passed... + ticker.Add(regularPollingPeriod) + require.Eventually(t, func() bool { return len(upstreamClient.GetCalls()) > initialListCalls }, time.Second, time.Nanosecond) @@ -192,6 +226,10 @@ func TestAgent(t *testing.T) { Payload: "new-license", UpdatedAt: expectedLicense.UpdatedAt.Add(time.Second), })) + + // Regular polling period has passed... + ticker.Add(regularPollingPeriod) + expectLicenseToMatchEventually(t, a, "new-license") }) }) diff --git a/internal/util/clock/clock.go b/internal/util/clock/clock.go new file mode 100644 index 0000000000..8d56a65d6c --- /dev/null +++ b/internal/util/clock/clock.go @@ -0,0 +1,7 @@ +package clock + +import "time" + +type System struct{} + +func (System) Now() time.Time { return time.Now() } diff --git a/internal/util/clock/ticker.go b/internal/util/clock/ticker.go new file mode 100644 index 0000000000..1faa4d2c5e --- /dev/null +++ b/internal/util/clock/ticker.go @@ -0,0 +1,34 @@ +package clock + +import "time" + +const ( + // This is irrelevant for the ticker, but we need to pass something to NewTicker. + // The reason for this is that the ticker is used in the license agent, which + // uses a non trivial logic to determine the polling period based on the state + // of license retrieval. + // This might be changed in the future if it doesn't fit the future needs. + initialTickerDuration = 1000 * time.Hour +) + +func NewTicker() *TimeTicker { + return &TimeTicker{ + ticker: time.NewTicker(initialTickerDuration), + } +} + +type TimeTicker struct { + ticker *time.Ticker +} + +func (t *TimeTicker) Stop() { + t.ticker.Stop() +} + +func (t *TimeTicker) Channel() <-chan time.Time { + return t.ticker.C +} + +func (t *TimeTicker) Reset(d time.Duration) { + t.ticker.Reset(d) +} diff --git a/test/mocks/ticker.go b/test/mocks/ticker.go new file mode 100644 index 0000000000..bee354d8fd --- /dev/null +++ b/test/mocks/ticker.go @@ -0,0 +1,90 @@ +package mocks + +import ( + "sync" + "time" +) + +const ( + // This is irrelevant for the ticker, but we need to pass something to NewTicker. + // The reason for this is that the ticker is used in the license agent, which + // uses a non trivial logic to determine the polling period based on the state + // of license retrieval. + // This might be changed in the future if it doesn't fit the future needs. + initialTickerDuration = 1000 * time.Hour +) + +func NewTicker() *Ticker { + now := time.Now() + + ticker := &Ticker{ + sigClose: make(chan struct{}), + d: initialTickerDuration, + ch: make(chan time.Time, 1), + time: now, + lastTick: now, + } + + return ticker +} + +type Ticker struct { + lock sync.RWMutex + sigClose chan struct{} + d time.Duration + ch chan time.Time + time time.Time + lastTick time.Time +} + +func (t *Ticker) Stop() { + close(t.sigClose) +} + +func (t *Ticker) Channel() <-chan time.Time { + return t.ch +} + +func (t *Ticker) Now() time.Time { + t.lock.RLock() + defer t.lock.RUnlock() + return t.time +} + +func (t *Ticker) Reset(d time.Duration) { + select { + case <-t.sigClose: + return + default: + } + + now := time.Now() + + t.lock.Lock() + defer t.lock.Unlock() + + t.lastTick = now + t.time = now + t.d = d +} + +func (t *Ticker) Add(d time.Duration) { + select { + case <-t.sigClose: + return + default: + } + + t.lock.Lock() + defer t.lock.Unlock() + + t.time = t.time.Add(d) + + if t.time.Compare(t.lastTick.Add(t.d)) >= 0 { + select { + case <-t.sigClose: + case t.ch <- t.time: + } + t.lastTick = t.time + } +} diff --git a/test/mocks/ticker_test.go b/test/mocks/ticker_test.go new file mode 100644 index 0000000000..346cad4c99 --- /dev/null +++ b/test/mocks/ticker_test.go @@ -0,0 +1,89 @@ +package mocks + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTicker(t *testing.T) { + t.Run("basic", func(t *testing.T) { + ticker := NewTicker() + ch := ticker.Channel() + select { + case <-ch: + require.FailNow(t, "unexpected tick") + default: + } + + ticker.Reset(time.Hour) + + t.Log("adding second should not tick when ticker has an interval of 1 hour") + ticker.Add(time.Second) + select { + case <-ch: + require.FailNow(t, "unexpected tick") + default: + } + + t.Log("adding 40 minutes should not tick when ticker has an interval of 1 hour") + ticker.Add(40 * time.Minute) + select { + case <-ch: + require.FailNow(t, "unexpected tick") + default: + } + + t.Log("adding 40 minutes should tick when 40 minutes already passed and ticker has an interval of 1 hour") + ticker.Add(40 * time.Minute) + select { + case <-ch: + case <-time.After(time.Second): + require.FailNow(t, "expected a tick to happen but it didn't") + } + }) + + t.Run("Reset", func(t *testing.T) { + ticker := NewTicker() + ch := ticker.Channel() + + t.Log("reseting ticker to 3 hour interval") + ticker.Reset(3 * time.Hour) + t.Log("adding second should not tick when ticker has an interval of 3 hours") + ticker.Add(time.Second) + select { + case <-ch: + require.FailNow(t, "unexpected tick") + default: + } + + t.Log("adding an hour should not tick when ticker has an interval of 3 hour") + ticker.Add(time.Hour) + select { + case <-ch: + require.FailNow(t, "unexpected tick") + default: + } + + t.Log("adding 2 hours should tick when ticker has an interval of 3 hour") + ticker.Add(2 * time.Hour) + select { + case <-ch: + case <-time.After(time.Second): + require.FailNow(t, "expected a tick to happen but it didn't") + } + }) + + t.Run("stop", func(t *testing.T) { + ticker := NewTicker() + ch := ticker.Channel() + ticker.Stop() + + select { + case <-ch: + require.FailNow(t, "unexpected tick") + default: + } + }) +}