diff --git a/internal/adminapi/konnect.go b/internal/adminapi/konnect.go index 62d21203ac..ac8d8b9c1c 100644 --- a/internal/adminapi/konnect.go +++ b/internal/adminapi/konnect.go @@ -20,12 +20,15 @@ type KonnectConfig struct { // TODO https://github.com/Kong/kubernetes-ingress-controller/issues/3922 // ConfigSynchronizationEnabled is the only toggle we had prior to the addition of the license agent. // We likely want to combine these into a single Konnect toggle or piggyback off other Konnect functionality. - ConfigSynchronizationEnabled bool + ConfigSynchronizationEnabled bool + RuntimeGroupID string + Address string + RefreshNodePeriod time.Duration + TLSClient TLSClientConfig + LicenseSynchronizationEnabled bool - RuntimeGroupID string - Address string - RefreshNodePeriod time.Duration - TLSClient TLSClientConfig + InitialLicensePollingPeriod time.Duration + LicensePollingPeriod time.Duration } func NewKongClientForKonnectRuntimeGroup(c KonnectConfig) (*KonnectClient, error) { diff --git a/internal/dataplane/parser/parser.go b/internal/dataplane/parser/parser.go index 2fa3dab72f..13d4603569 100644 --- a/internal/dataplane/parser/parser.go +++ b/internal/dataplane/parser/parser.go @@ -9,6 +9,7 @@ import ( "github.com/blang/semver/v4" "github.com/kong/go-kong/kong" + "github.com/samber/mo" "github.com/sirupsen/logrus" corev1 "k8s.io/api/core/v1" discoveryv1 "k8s.io/api/discovery/v1" @@ -124,7 +125,8 @@ func shouldEnableParserExpressionRoutes( // LicenseGetter is an interface for getting the Kong Enterprise license. type LicenseGetter interface { - GetLicense() kong.License + // GetLicense returns an optional license. + GetLicense() mo.Option[kong.License] } // Parser parses Kubernetes objects and configurations into their @@ -239,7 +241,10 @@ func (p *Parser) BuildKongConfig() KongConfigBuildingResult { result.CACertificates = p.getCACerts() if p.licenseGetter != nil { - result.Licenses = append(result.Licenses, p.licenseGetter.GetLicense()) + optionalLicense := p.licenseGetter.GetLicense() + if l, ok := optionalLicense.Get(); ok { + result.Licenses = append(result.Licenses, l) + } } if p.featureFlags.FillIDs { diff --git a/internal/dataplane/parser/parser_test.go b/internal/dataplane/parser/parser_test.go index 5773006efa..67b1fdb89d 100644 --- a/internal/dataplane/parser/parser_test.go +++ b/internal/dataplane/parser/parser_test.go @@ -10,6 +10,7 @@ import ( "github.com/blang/semver/v4" "github.com/kong/go-kong/kong" "github.com/samber/lo" + "github.com/samber/mo" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" @@ -5314,6 +5315,45 @@ func TestNewFeatureFlags(t *testing.T) { } } +type mockLicenseGetter struct { + license mo.Option[kong.License] +} + +func (m *mockLicenseGetter) GetLicense() mo.Option[kong.License] { + return m.license +} + +func TestParser_License(t *testing.T) { + s, _ := store.NewFakeStore(store.FakeObjects{}) + p := mustNewParser(t, s) + + t.Run("no license is populated by default", func(t *testing.T) { + result := p.BuildKongConfig() + require.Empty(t, result.KongState.Licenses) + }) + + t.Run("no license is populated when license getter returns no license", func(t *testing.T) { + p.InjectLicenseGetter(&mockLicenseGetter{}) + result := p.BuildKongConfig() + require.Empty(t, result.KongState.Licenses) + }) + + t.Run("license is populated when license getter returns a license", func(t *testing.T) { + licenseGetterWithLicense := &mockLicenseGetter{ + license: mo.Some(kong.License{ + ID: lo.ToPtr("license-id"), + Payload: lo.ToPtr("license-payload"), + }), + } + p.InjectLicenseGetter(licenseGetterWithLicense) + result := p.BuildKongConfig() + require.Len(t, result.KongState.Licenses, 1) + license := result.KongState.Licenses[0] + require.Equal(t, "license-id", *license.ID) + require.Equal(t, "license-payload", *license.Payload) + }) +} + func mustNewParser(t *testing.T, storer store.Storer) *Parser { const testKongVersion = "3.2.0" diff --git a/internal/konnect/license/client.go b/internal/konnect/license/client.go index be9827ab23..0b404e1412 100644 --- a/internal/konnect/license/client.go +++ b/internal/konnect/license/client.go @@ -4,13 +4,17 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" "io" "net/http" neturl "net/url" - "strconv" + "time" + + "github.com/samber/mo" "github.com/kong/kubernetes-ingress-controller/v2/internal/adminapi" + "github.com/kong/kubernetes-ingress-controller/v2/internal/license" tlsutil "github.com/kong/kubernetes-ingress-controller/v2/internal/util/tls" ) @@ -53,25 +57,38 @@ func (c *Client) kicLicenseAPIEndpoint() string { return fmt.Sprintf(KICLicenseAPIPathPattern, c.address, c.runtimeGroupID) } -func (c *Client) List(ctx context.Context, pageNumber int) (*ListLicenseResponse, error) { - // TODO this is another case where we have a pseudo-unary object. The page is always 0 in practice, but if we have - // separate functions per entity, we end up with effectively dead code for some - url, _ := neturl.Parse(c.kicLicenseAPIEndpoint()) - if pageNumber != 0 { - q := url.Query() - q.Set("page.number", strconv.Itoa(pageNumber)) - url.RawQuery = q.Encode() +func (c *Client) Get(ctx context.Context) (mo.Option[license.KonnectLicense], error) { + // Make a request to the Konnect license API to list all licenses. + response, err := c.listLicenses(ctx) + if err != nil { + return mo.None[license.KonnectLicense](), fmt.Errorf("failed to list licenses: %w", err) } + + // Convert the response to a KonnectLicense - we're expecting only one license. + l, err := listLicensesResponseToKonnectLicense(response) + if err != nil { + return mo.None[license.KonnectLicense](), fmt.Errorf("failed to convert list licenses response: %w", err) + } + + return l, nil +} + +// isOKStatusCode returns true if the input HTTP status code is 2xx, in [200,300). +func isOKStatusCode(code int) bool { + return code >= 200 && code < 300 +} + +// listLicenses calls the Konnect license API to list all licenses. +func (c *Client) listLicenses(ctx context.Context) (*ListLicenseResponse, error) { + url, _ := neturl.Parse(c.kicLicenseAPIEndpoint()) req, err := http.NewRequestWithContext(ctx, "GET", url.String(), nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - httpResp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to get response: %w", err) } - defer httpResp.Body.Close() respBuf, err := io.ReadAll(httpResp.Body) @@ -79,6 +96,10 @@ func (c *Client) List(ctx context.Context, pageNumber int) (*ListLicenseResponse return nil, fmt.Errorf("failed to read response body: %w", err) } + if httpResp.StatusCode == http.StatusNotFound { + // 404 means no license is found which is a valid response. + return nil, nil + } if !isOKStatusCode(httpResp.StatusCode) { return nil, fmt.Errorf("non-success response from Koko: %d, resp body %s", httpResp.StatusCode, string(respBuf)) } @@ -91,7 +112,32 @@ func (c *Client) List(ctx context.Context, pageNumber int) (*ListLicenseResponse return resp, nil } -// isOKStatusCode returns true if the input HTTP status code is 2xx, in [200,300). -func isOKStatusCode(code int) bool { - return code >= 200 && code < 300 +// listLicensesResponseToKonnectLicense converts a ListLicenseResponse to a KonnectLicense. +// It validates the response and returns an error if the response is invalid. +func listLicensesResponseToKonnectLicense(response *ListLicenseResponse) (mo.Option[license.KonnectLicense], error) { + if response == nil { + // If the response is nil, it means no license was found. + return mo.None[license.KonnectLicense](), nil + } + if len(response.Items) == 0 { + return mo.None[license.KonnectLicense](), errors.New("no license item found in response") + } + + // We're expecting only one license. + item := response.Items[0] + if item.License == "" { + return mo.None[license.KonnectLicense](), errors.New("license item has empty license") + } + if item.UpdatedAt == 0 { + return mo.None[license.KonnectLicense](), errors.New("license item has empty updated_at") + } + if item.ID == "" { + return mo.None[license.KonnectLicense](), errors.New("license item has empty id") + } + + return mo.Some(license.KonnectLicense{ + ID: item.ID, + UpdatedAt: time.Unix(int64(item.UpdatedAt), 0), + Payload: item.License, + }), nil } diff --git a/internal/konnect/license/client_test.go b/internal/konnect/license/client_test.go new file mode 100644 index 0000000000..23babde7d4 --- /dev/null +++ b/internal/konnect/license/client_test.go @@ -0,0 +1,162 @@ +package license_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/kong/kubernetes-ingress-controller/v2/internal/adminapi" + "github.com/kong/kubernetes-ingress-controller/v2/internal/konnect/license" +) + +type mockKonnectLicenseServer struct { + response []byte + statusCode int +} + +func newMockKonnectLicenseServer(response []byte, statusCode int) *mockKonnectLicenseServer { + return &mockKonnectLicenseServer{ + response: response, + statusCode: statusCode, + } +} + +func (m *mockKonnectLicenseServer) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(m.statusCode) + _, _ = w.Write(m.response) +} + +func TestLicenseClient(t *testing.T) { + testCases := []struct { + name string + response []byte + status int + assertions func(t *testing.T, c *license.Client) + }{ + { + name: "200 valid response", + response: []byte(`{ + "items": [ + { + "payload": "some-license-content", + "updated_at": 1234567890, + "id": "some-license-id" + } + ] + }`), + status: http.StatusOK, + assertions: func(t *testing.T, c *license.Client) { + licenseOpt, err := c.Get(context.Background()) + require.NoError(t, err) + + l, ok := licenseOpt.Get() + require.True(t, ok) + require.Equal(t, "some-license-content", l.Payload) + require.Equal(t, int64(1234567890), l.UpdatedAt.Unix()) + }, + }, + { + name: "200 but empty response", + response: []byte(`{}`), + status: http.StatusOK, + assertions: func(t *testing.T, c *license.Client) { + _, err := c.Get(context.Background()) + require.ErrorContains(t, err, "no license item found in response") + }, + }, + { + name: "200 but invalid response", + response: []byte(`{invalid-json`), + status: http.StatusOK, + assertions: func(t *testing.T, c *license.Client) { + _, err := c.Get(context.Background()) + require.ErrorContains(t, err, "failed to parse response body") + }, + }, + { + name: "200 but empty license id", + response: []byte(`{ + "items": [ + { + "payload": "some-license-content", + "updated_at": 1234567890, + "id": "" + } + ] + }`), + status: http.StatusOK, + assertions: func(t *testing.T, c *license.Client) { + _, err := c.Get(context.Background()) + require.ErrorContains(t, err, "empty id") + }, + }, + { + name: "200 but empty updated_at", + response: []byte(`{ + "items": [ + { + "payload": "some-license-content", + "updated_at": 0, + "id": "some-license-id" + } + ] + }`), + status: http.StatusOK, + assertions: func(t *testing.T, c *license.Client) { + _, err := c.Get(context.Background()) + require.ErrorContains(t, err, "empty updated_at") + }, + }, + { + name: "200 but empty payload", + response: []byte(`{ + "items": [ + { + "payload": "", + "updated_at": 1234567890, + "id": "some-license-id" + } + ] + }`), + status: http.StatusOK, + assertions: func(t *testing.T, c *license.Client) { + _, err := c.Get(context.Background()) + require.ErrorContains(t, err, "empty license") + }, + }, + { + name: "404 returns empty license with no error", + response: nil, + status: http.StatusNotFound, + assertions: func(t *testing.T, c *license.Client) { + l, err := c.Get(context.Background()) + require.NoError(t, err) + require.False(t, l.IsPresent()) + }, + }, + { + name: "400 returns error", + response: nil, + status: http.StatusBadRequest, + assertions: func(t *testing.T, c *license.Client) { + _, err := c.Get(context.Background()) + require.Error(t, err) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := newMockKonnectLicenseServer(tc.response, tc.status) + ts := httptest.NewServer(server) + defer ts.Close() + + c, err := license.NewClient(adminapi.KonnectConfig{Address: ts.URL}) + require.NoError(t, err) + tc.assertions(t, c) + }) + } +} diff --git a/internal/konnect/license/types.go b/internal/konnect/license/types.go index 5a11cc1e27..4697cfe862 100644 --- a/internal/konnect/license/types.go +++ b/internal/konnect/license/types.go @@ -13,6 +13,5 @@ type ListLicenseResponse struct { type Item struct { License string `json:"payload,omitempty"` UpdatedAt uint64 `json:"updated_at,omitempty"` - CreatedAt uint64 `json:"created_at,omitempty"` ID string `json:"id,omitempty"` } diff --git a/internal/license/agent.go b/internal/license/agent.go index 054965a1f8..daf420bd04 100644 --- a/internal/license/agent.go +++ b/internal/license/agent.go @@ -2,51 +2,87 @@ package license import ( "context" - "fmt" "sync" "time" "github.com/go-logr/logr" "github.com/kong/go-kong/kong" + "github.com/samber/lo" + "github.com/samber/mo" - "github.com/kong/kubernetes-ingress-controller/v2/internal/konnect/license" "github.com/kong/kubernetes-ingress-controller/v2/internal/util" ) const ( - // PollingInterval is the interval at which the license agent will poll for license updates. - PollingInterval = time.Hour * 12 + // DefaultPollingPeriod is the period at which the license agent will poll for license updates by default. + DefaultPollingPeriod = time.Hour * 12 + + // DefaultInitialPollingPeriod is the period at which the license agent will poll for a license until it retrieves + // one. + DefaultInitialPollingPeriod = time.Minute // PollingTimeout is the timeout for retrieving a license from upstream. PollingTimeout = time.Minute * 5 ) -type UpstreamClient interface { - List(ctx context.Context, pageNumber int) (*license.ListLicenseResponse, error) +// KonnectLicense is a license retrieved from Konnect. +type KonnectLicense struct { + ID string + Payload string + UpdatedAt time.Time +} + +type KonnectLicenseClient interface { + Get(ctx context.Context) (mo.Option[KonnectLicense], error) +} + +type AgentOpt func(*Agent) + +// WithInitialPollingPeriod sets the initial polling period for the license agent. +func WithInitialPollingPeriod(initialPollingPeriod time.Duration) AgentOpt { + return func(a *Agent) { + a.initialPollingPeriod = initialPollingPeriod + } +} + +// WithPollingPeriod sets the regular polling period for the license agent. +func WithPollingPeriod(regularPollingPeriod time.Duration) AgentOpt { + return func(a *Agent) { + a.regularPollingPeriod = regularPollingPeriod + } } // NewAgent creates a new license agent that retrieves a license from the given url once every given period. func NewAgent( - konnectAPIClient UpstreamClient, + konnectLicenseClient KonnectLicenseClient, logger logr.Logger, + opts ...AgentOpt, ) *Agent { - return &Agent{ - logger: logger, - ticker: time.NewTicker(PollingInterval), - mutex: sync.RWMutex{}, - konnectAPIClient: konnectAPIClient, + a := &Agent{ + logger: logger, + konnectLicenseClient: konnectLicenseClient, + initialPollingPeriod: DefaultInitialPollingPeriod, + regularPollingPeriod: DefaultPollingPeriod, + } + + for _, opt := range opts { + opt(a) } + + return a } // Agent handles retrieving a Kong license and providing it to other KIC subsystems. type Agent struct { - logger logr.Logger - ticker *time.Ticker - mutex sync.RWMutex - konnectAPIClient UpstreamClient - - // license is the current license retrieved from upstream. - license license.Item + logger logr.Logger + konnectLicenseClient KonnectLicenseClient + initialPollingPeriod time.Duration + regularPollingPeriod time.Duration + + // cachedLicense is the current license retrieved from upstream. It's optional because we may not have retrieved a + // license yet. + cachedLicense mo.Option[KonnectLicense] + mutex sync.RWMutex } // NeedLeaderElection indicates if the Agent requires leadership to run. It always returns true. @@ -55,73 +91,109 @@ func (a *Agent) NeedLeaderElection() bool { } // Start starts the Agent. It attempts to pull an initial license from upstream, and then polls for updates on a -// regular interval defined by PollingInterval. +// regular period, either the agent's initialPollingPeriod if it has not yet obtained a license or regularPollingPeriod if it has. func (a *Agent) Start(ctx context.Context) error { a.logger.V(util.DebugLevel).Info("starting license agent") - err := a.updateLicense(ctx) + err := a.reconcileLicenseWithKonnect(ctx) if err != nil { + // If that happens, GetLicense() will return no license until we retrieve a valid one in polling. a.logger.Error(err, "could not retrieve license from upstream") } - return a.run(ctx) + return a.runPollingLoop(ctx) } // GetLicense returns the agent's current license as a go-kong License struct. It omits the origin timestamps, // as Kong will auto-populate these when adding the license to its config database. -func (a *Agent) GetLicense() kong.License { +// It's optional because we may not have retrieved a license yet. +func (a *Agent) GetLicense() mo.Option[kong.License] { a.logger.V(util.DebugLevel).Info("retrieving license from cache") a.mutex.RLock() defer a.mutex.RUnlock() - return kong.License{ - ID: kong.String(a.license.ID), - Payload: kong.String(a.license.License), + + if cachedLicense, ok := a.cachedLicense.Get(); ok { + return mo.Some(kong.License{ + ID: lo.ToPtr(cachedLicense.ID), + Payload: lo.ToPtr(cachedLicense.Payload), + }) } + + return mo.None[kong.License]() } -// run updates the license on a regular interval until the context is cancelled. -func (a *Agent) run(ctx context.Context) error { +// runPollingLoop updates the license on a regular period until the context is cancelled. +// It will run at a faster period initially, and then switch to the regular period once a license is retrieved. +func (a *Agent) runPollingLoop(ctx context.Context) error { + ticker := time.NewTicker(a.resolvePollingPeriod()) + defer ticker.Stop() + for { select { + case <-ticker.C: + a.logger.V(util.DebugLevel).Info("retrieving license from external service") + if err := a.reconcileLicenseWithKonnect(ctx); err != nil { + a.logger.Error(err, "could not reconcile license with Konnect") + } + // Reset the ticker to run with the expected period which may change after we receive the license. + ticker.Reset(a.resolvePollingPeriod()) case <-ctx.Done(): a.logger.Info("context done, shutting down license agent") - a.ticker.Stop() return ctx.Err() - case <-a.ticker.C: - a.logger.V(util.DebugLevel).Info("retrieving license from external service") - if err := a.updateLicense(ctx); err != nil { - a.logger.Error(err, "could not update license") - } } } } -// updateLicense retrievs a license from an outside system. If it successfully retrieves a license, it updates the -// in-memory license cache. -func (a *Agent) updateLicense(ctx context.Context) error { - ctx, cancel := context.WithTimeout(ctx, PollingTimeout) - defer cancel() +func (a *Agent) resolvePollingPeriod() time.Duration { + // If we already have a license, start with the regular polling period (happy path) ... + if a.cachedLicense.IsPresent() { + return a.regularPollingPeriod + } + // ... otherwise, start with the initial polling period which is shorter by default (to get a license faster + // when it appears, e.g. when a user upgrades from Free to Enterprise tier). + return a.initialPollingPeriod +} - // This is an array because it's a Kong entity collection, even though we only expect to have exactly one license. - licenses, err := a.konnectAPIClient.List(ctx, 0) +// reconcileLicenseWithKonnect retrieves a license from upstream and caches it if it is newer than the cached license or there is no cached license. +func (a *Agent) reconcileLicenseWithKonnect(ctx context.Context) error { + retrievedLicenseOpt, err := a.retrieveLicenseFromUpstream(ctx) if err != nil { - return fmt.Errorf("could not retrieve license: %w", err) + return err } - if len(licenses.Items) == 0 { - return fmt.Errorf("received empty license response") + + retrievedLicense, retrievedLicenseOk := retrievedLicenseOpt.Get() + if !retrievedLicenseOk { + // If we get no license from Konnect, we cannot do anything. + a.logger.V(util.DebugLevel).Info("no license found in Konnect") + return nil } - license := licenses.Items[0] - if license.UpdatedAt > a.license.UpdatedAt { - a.logger.V(util.InfoLevel).Info("updating license cache", - "old_updated_at", time.Unix(int64(a.license.UpdatedAt), 0).String(), - "new_updated_at", time.Unix(int64(license.UpdatedAt), 0).String(), + + if a.cachedLicense.IsAbsent() { + a.logger.V(util.InfoLevel).Info("caching initial license retrieved from the upstream", + "updated_at", retrievedLicense.UpdatedAt.String(), + ) + a.updateCache(retrievedLicense) + } else if cachedLicense, ok := a.cachedLicense.Get(); ok && retrievedLicense.UpdatedAt.After(cachedLicense.UpdatedAt) { + a.logger.V(util.InfoLevel).Info("caching license retrieved from the upstream as it is newer than the cached one", + "cached_updated_at", cachedLicense.UpdatedAt.String(), + "retrieved_updated_at", retrievedLicense.UpdatedAt.String(), ) - a.mutex.Lock() - defer a.mutex.Unlock() - a.license = *license + a.updateCache(retrievedLicense) } else { a.logger.V(util.DebugLevel).Info("license cache is up to date") } return nil } + +func (a *Agent) retrieveLicenseFromUpstream(ctx context.Context) (mo.Option[KonnectLicense], error) { + ctx, cancel := context.WithTimeout(ctx, PollingTimeout) + defer cancel() + return a.konnectLicenseClient.Get(ctx) +} + +func (a *Agent) updateCache(license KonnectLicense) { + a.mutex.Lock() + defer a.mutex.Unlock() + a.cachedLicense = mo.Some(license) +} diff --git a/internal/license/agent_test.go b/internal/license/agent_test.go index 9387ff6397..b41135abbd 100644 --- a/internal/license/agent_test.go +++ b/internal/license/agent_test.go @@ -2,49 +2,197 @@ package license_test import ( "context" + "errors" + "sync" "testing" "time" "github.com/go-logr/logr" + "github.com/samber/mo" "github.com/stretchr/testify/require" - konnectLicense "github.com/kong/kubernetes-ingress-controller/v2/internal/konnect/license" "github.com/kong/kubernetes-ingress-controller/v2/internal/license" ) -type mockUpstreamClient struct { - listResponse *konnectLicense.ListLicenseResponse +type mockKonnectClientClient struct { + konnectLicense mo.Option[license.KonnectLicense] + err error + getCalls []time.Time + lock sync.RWMutex } -func (m *mockUpstreamClient) List(context.Context, int) (*konnectLicense.ListLicenseResponse, error) { - return m.listResponse, nil +func newMockKonnectLicenseClient(license mo.Option[license.KonnectLicense]) *mockKonnectClientClient { + return &mockKonnectClientClient{konnectLicense: license} +} + +func (m *mockKonnectClientClient) Get(context.Context) (mo.Option[license.KonnectLicense], error) { + m.lock.Lock() + defer m.lock.Unlock() + + m.getCalls = append(m.getCalls, time.Now()) + + if m.err != nil { + return mo.None[license.KonnectLicense](), m.err + } + return m.konnectLicense, nil +} + +func (m *mockKonnectClientClient) ReturnError(err error) { + m.lock.Lock() + defer m.lock.Unlock() + m.err = err +} + +func (m *mockKonnectClientClient) ReturnSuccess(license mo.Option[license.KonnectLicense]) { + m.lock.Lock() + defer m.lock.Unlock() + m.konnectLicense = license + m.err = nil +} + +func (m *mockKonnectClientClient) GetCalls() []time.Time { + m.lock.RLock() + defer m.lock.RUnlock() + + copied := make([]time.Time, len(m.getCalls)) + copy(copied, m.getCalls) + return copied } func TestAgent(t *testing.T) { - expectedLicense := &konnectLicense.Item{ - License: "test-license", - UpdatedAt: 1234567890, + t.Parallel() + + ctx := context.Background() + + expectedLicense := license.KonnectLicense{ + Payload: "test-license", + UpdatedAt: time.Now(), } - upstreamClient := &mockUpstreamClient{ - listResponse: &konnectLicense.ListLicenseResponse{ - Items: []*konnectLicense.Item{ - expectedLicense, - }, - }, + + expectLicenseToMatchEventually := func(t *testing.T, a *license.Agent, expectedPayload string) time.Time { + var matchTime time.Time + require.Eventually(t, func() bool { + actualLicense, ok := a.GetLicense().Get() + if !ok { + t.Log("license not yet available") + return false + } + if *actualLicense.Payload != expectedPayload { + t.Logf("license mismatch: expected %q, got %q", expectedPayload, *actualLicense.Payload) + return false + } + matchTime = time.Now() + return true + }, time.Second, time.Nanosecond) + return matchTime } - a := license.NewAgent(upstreamClient, logr.Discard()) - ctx := context.Background() - go func() { - err := a.Start(ctx) - require.NoError(t, err) - }() - - require.Eventually(t, func() bool { - actualLicense := a.GetLicense() - if actualLicense.Payload == nil { - return false - } - return *actualLicense.Payload == expectedLicense.License - }, time.Second*5, time.Millisecond) + t.Run("initial license is retrieved", func(t *testing.T) { + upstreamClient := newMockKonnectLicenseClient(mo.Some(expectedLicense)) + a := license.NewAgent(upstreamClient, logr.Discard()) + go func() { + err := a.Start(ctx) + require.NoError(t, err) + }() + expectLicenseToMatchEventually(t, a, expectedLicense.Payload) + }) + + t.Run("initial license retrieval fails and recovers", func(t *testing.T) { + upstreamClient := newMockKonnectLicenseClient(mo.None[license.KonnectLicense]()) + + // Return an error on the first call to List() to verify that the agent handles this correctly. + upstreamClient.ReturnError(errors.New("something went wrong on a backend")) + + const ( + // Set the initial polling period to a very short duration to ensure that the agent retries quickly. + initialPollingPeriod = time.Millisecond + regularPollingPeriod = time.Millisecond * 5 + allowedDelta = time.Millisecond + ) + a := license.NewAgent( + upstreamClient, + logr.Discard(), + license.WithInitialPollingPeriod(initialPollingPeriod), + license.WithPollingPeriod(regularPollingPeriod), + ) + + startTime := time.Now() + go func() { + err := a.Start(ctx) + require.NoError(t, err) + }() + + t.Run("initial polling period is used when no license is retrieved", func(t *testing.T) { + require.Eventually(t, func() bool { + return len(upstreamClient.GetCalls()) >= 1 + }, time.Second, time.Nanosecond, "expected upstream client to be called at least once") + + firstListCallTime := upstreamClient.GetCalls()[0] + require.WithinDuration(t, startTime.Add(initialPollingPeriod), firstListCallTime, allowedDelta, + "expected first call to List() to happen after the initial polling period") + + require.Eventually(t, func() bool { + return len(upstreamClient.GetCalls()) >= 2 + }, time.Second, time.Nanosecond, "expected upstream client to be called at least twice") + + secondListCallTime := upstreamClient.GetCalls()[1] + require.WithinDuration(t, firstListCallTime.Add(initialPollingPeriod), secondListCallTime, allowedDelta, + "expected second call to List() to happen after the initial polling period as no license is retrieved yet") + + require.False(t, a.GetLicense().IsPresent(), "no license should be available due to an error in the upstream client") + }) + + t.Run("regular polling period is used after the initial license is retrieved", func(t *testing.T) { + // Now return a valid response to ensure that the agent recovers. + upstreamClient.ReturnSuccess(mo.Some(expectedLicense)) + expectLicenseToMatchEventually(t, a, expectedLicense.Payload) + + listCallsAfterMatchCount := len(upstreamClient.GetCalls()) + require.Eventually(t, func() bool { + return len(upstreamClient.GetCalls()) > listCallsAfterMatchCount + }, time.Second, time.Nanosecond, "expected upstream client to be called at least once after the license is retrieved") + + listCalls := upstreamClient.GetCalls() + lastListCall := listCalls[len(listCalls)-1] + lastButOneCall := listCalls[len(listCalls)-2] + require.WithinDuration(t, lastButOneCall.Add(regularPollingPeriod), lastListCall, allowedDelta) + }) + + t.Run("after the license is retrieved, errors returned from upstream do not override the license", func(t *testing.T) { + upstreamClient.ReturnError(errors.New("something went wrong on a backend")) + + // Wait for the call to happen. + initialListCalls := len(upstreamClient.GetCalls()) + require.Eventually(t, func() bool { + return len(upstreamClient.GetCalls()) > initialListCalls + }, time.Second, time.Nanosecond) + + // The license should still be available. + require.True(t, a.GetLicense().IsPresent(), "license should be available even if the upstream client returns an error") + }) + + t.Run("license is not updated when the upstream returns a license updated before the cached one", func(t *testing.T) { + upstreamClient.ReturnSuccess(mo.Some(license.KonnectLicense{ + Payload: "new-license", + UpdatedAt: expectedLicense.UpdatedAt.Add(-time.Second), + })) + + // Wait for the call to happen. + initialListCalls := len(upstreamClient.GetCalls()) + require.Eventually(t, func() bool { + return len(upstreamClient.GetCalls()) > initialListCalls + }, time.Second, time.Nanosecond) + + // The cached license should still be available. + expectLicenseToMatchEventually(t, a, expectedLicense.Payload) + }) + + t.Run("license is updated when the upstream returns a license updated after the cached one", func(t *testing.T) { + upstreamClient.ReturnSuccess(mo.Some(license.KonnectLicense{ + Payload: "new-license", + UpdatedAt: expectedLicense.UpdatedAt.Add(time.Second), + })) + expectLicenseToMatchEventually(t, a, "new-license") + }) + }) } diff --git a/internal/manager/config.go b/internal/manager/config.go index 3fddb03621..152bc3cffe 100644 --- a/internal/manager/config.go +++ b/internal/manager/config.go @@ -18,6 +18,7 @@ import ( "github.com/kong/kubernetes-ingress-controller/v2/internal/controllers/gateway" "github.com/kong/kubernetes-ingress-controller/v2/internal/dataplane" "github.com/kong/kubernetes-ingress-controller/v2/internal/konnect" + "github.com/kong/kubernetes-ingress-controller/v2/internal/license" cfgtypes "github.com/kong/kubernetes-ingress-controller/v2/internal/manager/config/types" "github.com/kong/kubernetes-ingress-controller/v2/internal/manager/featuregates" "github.com/kong/kubernetes-ingress-controller/v2/internal/manager/flags" @@ -246,6 +247,8 @@ func (c *Config) FlagSet() *pflag.FlagSet { // Konnect flagSet.BoolVar(&c.Konnect.ConfigSynchronizationEnabled, "konnect-sync-enabled", false, "Enable synchronization of data plane configuration with a Konnect runtime group.") flagSet.BoolVar(&c.Konnect.LicenseSynchronizationEnabled, "konnect-licensing-enabled", false, "Retrieve licenses from Konnect if available. Overrides licenses provided via the environment.") + flagSet.DurationVar(&c.Konnect.InitialLicensePollingPeriod, "konnect-initial-license-polling-period", license.DefaultInitialPollingPeriod, "Polling period to be used before the first license is retrieved.") + flagSet.DurationVar(&c.Konnect.LicensePollingPeriod, "konnect-license-polling-period", license.DefaultPollingPeriod, "Polling period to be used after the first license is retrieved.") flagSet.StringVar(&c.Konnect.RuntimeGroupID, "konnect-runtime-group-id", "", "An ID of a runtime group that is to be synchronized with data plane configuration.") flagSet.StringVar(&c.Konnect.Address, "konnect-address", "https://us.kic.api.konghq.com", "Base address of Konnect API.") flagSet.StringVar(&c.Konnect.TLSClient.Cert, "konnect-tls-client-cert", "", "Konnect TLS client certificate.") diff --git a/internal/manager/run.go b/internal/manager/run.go index dd24521c1e..4cd582f367 100644 --- a/internal/manager/run.go +++ b/internal/manager/run.go @@ -254,7 +254,12 @@ func Run(ctx context.Context, c *Config, diagnostic util.ConfigDumpDiagnostic, d return fmt.Errorf("failed creating konnect client: %w", err) } setupLog.Info("starting license agent") - agent := license.NewAgent(konnectLicenseAPIClient, ctrl.Log.WithName("license-agent")) + agent := license.NewAgent( + konnectLicenseAPIClient, + ctrl.Log.WithName("license-agent"), + license.WithInitialPollingPeriod(c.Konnect.InitialLicensePollingPeriod), + license.WithPollingPeriod(c.Konnect.LicensePollingPeriod), + ) err = mgr.Add(agent) if err != nil { return fmt.Errorf("could not add license agent to manager: %w", err)