diff --git a/cns/NetworkContainerContract.go b/cns/NetworkContainerContract.go index 77ba451414..71e66928a8 100644 --- a/cns/NetworkContainerContract.go +++ b/cns/NetworkContainerContract.go @@ -15,6 +15,7 @@ import ( // Container Network Service DNC Contract const ( SetOrchestratorType = "/network/setorchestratortype" + GetHomeAz = "/homeaz" CreateOrUpdateNetworkContainer = "/network/createorupdatenetworkcontainer" DeleteNetworkContainer = "/network/deletenetworkcontainer" PublishNetworkContainer = "/network/publishnetworkcontainer" diff --git a/cns/api.go b/cns/api.go index 1e2b8349db..fd3ef728d9 100644 --- a/cns/api.go +++ b/cns/api.go @@ -65,6 +65,7 @@ type IPConfigurationStatus struct { // Equals compares a subset of the IPConfigurationStatus fields since a direct // DeepEquals or otherwise complete comparison of two IPConfigurationStatus objects // compares internal state details that don't impact their functional equality. +// //nolint:gocritic // it's safer to pass this by value func (i *IPConfigurationStatus) Equals(o IPConfigurationStatus) bool { if i.PodInfo != nil && o.PodInfo != nil { @@ -107,6 +108,7 @@ func (i *IPConfigurationStatus) String() string { // a struct that has public fields for the original struct's private fields, // embed the original struct in an anonymous struct as the alias type, and then // let the default marshaller do its magic. +// //nolint:gocritic // ignore hugeParam it's a value receiver on purpose func (i IPConfigurationStatus) MarshalJSON() ([]byte, error) { type alias IPConfigurationStatus @@ -335,3 +337,13 @@ type NmAgentSupportedApisResponse struct { Response Response SupportedApis []string } + +type HomeAzResponse struct { + IsSupported bool `json:"isSupported"` + HomeAz uint `json:"homeAz"` +} + +type GetHomeAzResponse struct { + Response Response `json:"response"` + HomeAzResponse HomeAzResponse `json:"homeAzResponse"` +} diff --git a/cns/client/client.go b/cns/client/client.go index 8f583abdd3..41b9e7ae56 100644 --- a/cns/client/client.go +++ b/cns/client/client.go @@ -41,6 +41,7 @@ var clientPaths = []string{ cns.NMAgentSupportedAPIs, cns.DeleteNetworkContainer, cns.NetworkContainersURLPath, + cns.GetHomeAz, } type do interface { @@ -811,3 +812,38 @@ func (c *Client) PostAllNetworkContainers(ctx context.Context, createNcRequest c return nil } + +// GetHomeAz gets home AZ of host +func (c *Client) GetHomeAz(ctx context.Context) (*cns.GetHomeAzResponse, error) { + // build the request + u := c.routes[cns.GetHomeAz] + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), http.NoBody) + if err != nil { + return nil, errors.Wrap(err, "building http request") + } + + // submit the request + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "sending HTTP request") + } + defer resp.Body.Close() + + // decode the response + var getHomeAzResponse cns.GetHomeAzResponse + err = json.NewDecoder(resp.Body).Decode(&getHomeAzResponse) + if err != nil { + return nil, errors.Wrap(err, "decoding response as JSON") + } + + // if the return code is non-zero, something went wrong and it should be + // surfaced to the caller + if getHomeAzResponse.Response.ReturnCode != 0 { + return nil, &CNSClientError{ + Code: getHomeAzResponse.Response.ReturnCode, + Err: errors.New(getHomeAzResponse.Response.Message), + } + } + + return &getHomeAzResponse, nil +} diff --git a/cns/client/client_test.go b/cns/client/client_test.go index 582c2754b7..887741e4ce 100644 --- a/cns/client/client_test.go +++ b/cns/client/client_test.go @@ -169,7 +169,7 @@ func TestMain(m *testing.M) { logger.InitLogger(logName, 0, 0, tmpLogDir+"/") config := common.ServiceConfig{} - httpRestService, err := restserver.NewHTTPRestService(&config, &fakes.WireserverClientFake{}, &fakes.NMAgentClientFake{}, nil, nil) + httpRestService, err := restserver.NewHTTPRestService(&config, &fakes.WireserverClientFake{}, &fakes.NMAgentClientFake{}, nil, nil, nil) svc = httpRestService.(*restserver.HTTPRestService) svc.Name = "cns-test-server" fakeNNC := v1alpha.NodeNetworkConfig{ @@ -2260,3 +2260,66 @@ func TestPostAllNetworkContainers(t *testing.T) { }) } } + +func TestGetHomeAz(t *testing.T) { + emptyRoutes, _ := buildRoutes(defaultBaseURL, clientPaths) + tests := []struct { + name string + shouldErr bool + exp *cns.GetHomeAzResponse + }{ + { + "happy path", + false, + &cns.GetHomeAzResponse{ + Response: cns.Response{ + ReturnCode: 0, + Message: "success", + }, + HomeAzResponse: cns.HomeAzResponse{ + IsSupported: true, + HomeAz: uint(1), + }, + }, + }, + { + "error", + true, + &cns.GetHomeAzResponse{ + Response: cns.Response{ + ReturnCode: types.UnexpectedError, + Message: "unexpected error", + }, + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + client := &Client{ + client: &mockdo{ + errToReturn: nil, + objToReturn: test.exp, + httpStatusCodeToReturn: http.StatusOK, + }, + routes: emptyRoutes, + } + + got, err := client.GetHomeAz(context.Background()) + if err != nil && !test.shouldErr { + t.Fatal("unexpected error: err:", err) + } + + if err == nil && test.shouldErr { + t.Fatal("expected an error but received none") + } + + if !test.shouldErr && !cmp.Equal(got, test.exp) { + t.Error("received response differs from expectation: diff:", cmp.Diff(got, test.exp)) + } + }) + } +} diff --git a/cns/configuration/cns_config.json b/cns/configuration/cns_config.json index fdd5fddfe6..23ec3116ad 100644 --- a/cns/configuration/cns_config.json +++ b/cns/configuration/cns_config.json @@ -28,5 +28,6 @@ }, "MSISettings": { "ResourceID": "" - } + }, + "PopulateHomeAzCacheRetryIntervalSecs": 15 } diff --git a/cns/configuration/configuration.go b/cns/configuration/configuration.go index 0b3d4371cb..53da2828e1 100644 --- a/cns/configuration/configuration.go +++ b/cns/configuration/configuration.go @@ -45,6 +45,7 @@ type CNSConfig struct { CNIConflistScenario string EnableCNIConflistGeneration bool CNIConflistFilepath string + PopulateHomeAzCacheRetryIntervalSecs int } type TelemetrySettings struct { @@ -239,4 +240,8 @@ func SetCNSConfigDefaults(config *CNSConfig) { if config.SyncHostNCTimeoutMs == 0 { config.SyncHostNCTimeoutMs = 500 //nolint:gomnd // default times } + if config.PopulateHomeAzCacheRetryIntervalSecs == 0 { + // set the default PopulateHomeAzCache retry interval to 15 seconds + config.PopulateHomeAzCacheRetryIntervalSecs = 15 + } } diff --git a/cns/configuration/configuration_test.go b/cns/configuration/configuration_test.go index 41299637e7..786deb36c3 100644 --- a/cns/configuration/configuration_test.go +++ b/cns/configuration/configuration_test.go @@ -207,6 +207,7 @@ func TestSetCNSConfigDefaults(t *testing.T) { KeyVaultSettings: KeyVaultSettings{ RefreshIntervalInHrs: 12, }, + PopulateHomeAzCacheRetryIntervalSecs: 15, }, }, { @@ -229,6 +230,7 @@ func TestSetCNSConfigDefaults(t *testing.T) { KeyVaultSettings: KeyVaultSettings{ RefreshIntervalInHrs: 3, }, + PopulateHomeAzCacheRetryIntervalSecs: 10, }, want: CNSConfig{ ChannelMode: "Other", @@ -248,6 +250,7 @@ func TestSetCNSConfigDefaults(t *testing.T) { KeyVaultSettings: KeyVaultSettings{ RefreshIntervalInHrs: 3, }, + PopulateHomeAzCacheRetryIntervalSecs: 10, }, }, } diff --git a/cns/fakes/nmagentclientfake.go b/cns/fakes/nmagentclientfake.go index d9e16929e6..94babee35e 100644 --- a/cns/fakes/nmagentclientfake.go +++ b/cns/fakes/nmagentclientfake.go @@ -20,28 +20,33 @@ type NMAgentClientFake struct { SupportedAPIsF func(context.Context) ([]string, error) GetNCVersionF func(context.Context, nmagent.NCVersionRequest) (nmagent.NCVersion, error) GetNCVersionListF func(context.Context) (nmagent.NCVersionList, error) + GetHomeAzF func(context.Context) (nmagent.AzResponse, error) } -func (c *NMAgentClientFake) PutNetworkContainer(ctx context.Context, req *nmagent.PutNetworkContainerRequest) error { - return c.PutNetworkContainerF(ctx, req) +func (n *NMAgentClientFake) PutNetworkContainer(ctx context.Context, req *nmagent.PutNetworkContainerRequest) error { + return n.PutNetworkContainerF(ctx, req) } -func (c *NMAgentClientFake) DeleteNetworkContainer(ctx context.Context, req nmagent.DeleteContainerRequest) error { - return c.DeleteNetworkContainerF(ctx, req) +func (n *NMAgentClientFake) DeleteNetworkContainer(ctx context.Context, req nmagent.DeleteContainerRequest) error { + return n.DeleteNetworkContainerF(ctx, req) } -func (c *NMAgentClientFake) JoinNetwork(ctx context.Context, req nmagent.JoinNetworkRequest) error { - return c.JoinNetworkF(ctx, req) +func (n *NMAgentClientFake) JoinNetwork(ctx context.Context, req nmagent.JoinNetworkRequest) error { + return n.JoinNetworkF(ctx, req) } -func (c *NMAgentClientFake) SupportedAPIs(ctx context.Context) ([]string, error) { - return c.SupportedAPIsF(ctx) +func (n *NMAgentClientFake) SupportedAPIs(ctx context.Context) ([]string, error) { + return n.SupportedAPIsF(ctx) } -func (c *NMAgentClientFake) GetNCVersion(ctx context.Context, req nmagent.NCVersionRequest) (nmagent.NCVersion, error) { - return c.GetNCVersionF(ctx, req) +func (n *NMAgentClientFake) GetNCVersion(ctx context.Context, req nmagent.NCVersionRequest) (nmagent.NCVersion, error) { + return n.GetNCVersionF(ctx, req) } -func (c *NMAgentClientFake) GetNCVersionList(ctx context.Context) (nmagent.NCVersionList, error) { - return c.GetNCVersionListF(ctx) +func (n *NMAgentClientFake) GetNCVersionList(ctx context.Context) (nmagent.NCVersionList, error) { + return n.GetNCVersionListF(ctx) +} + +func (n *NMAgentClientFake) GetHomeAz(ctx context.Context) (nmagent.AzResponse, error) { + return n.GetHomeAzF(ctx) } diff --git a/cns/restserver/api.go b/cns/restserver/api.go index cc6539bb2d..b02132bc54 100644 --- a/cns/restserver/api.go +++ b/cns/restserver/api.go @@ -35,7 +35,9 @@ var ( // 3) the ncid parameter // 4) the authentication token parameter // 5) the optional delete path -const ncURLExpectedMatches = 5 +const ( + ncURLExpectedMatches = 5 +) // This file contains implementation of all HTTP APIs which are exposed to external clients. // TODO: break it even further per module (network, nc, etc) like it is done for ipam @@ -764,6 +766,25 @@ func (service *HTTPRestService) setOrchestratorType(w http.ResponseWriter, r *ht logger.Response(service.Name, resp, resp.ReturnCode, err) } +// getHomeAz retrieves home AZ of host +func (service *HTTPRestService) getHomeAz(w http.ResponseWriter, r *http.Request) { + logger.Printf("[Azure CNS] getHomeAz") + logger.Request(service.Name, "getHomeAz", nil) + ctx := r.Context() + + switch r.Method { + case http.MethodGet: + getHomeAzResponse := service.homeAzMonitor.GetHomeAz(ctx) + service.setResponse(w, getHomeAzResponse.Response.ReturnCode, getHomeAzResponse) + default: + returnMessage := "[Azure CNS] Error. getHomeAz did not receive a GET." + returnCode := types.UnsupportedVerb + service.setResponse(w, returnCode, cns.GetHomeAzResponse{ + Response: cns.Response{ReturnCode: returnCode, Message: returnMessage}, + }) + } +} + func (service *HTTPRestService) createOrUpdateNetworkContainer(w http.ResponseWriter, r *http.Request) { logger.Printf("[Azure CNS] createOrUpdateNetworkContainer") diff --git a/cns/restserver/api_test.go b/cns/restserver/api_test.go index 9491517cfa..ebf7d7c638 100644 --- a/cns/restserver/api_test.go +++ b/cns/restserver/api_test.go @@ -977,6 +977,24 @@ func TestNmAgentSupportedApisHandler(t *testing.T) { fmt.Printf("nmAgentSupportedApisHandler Responded with %+v\n", nmAgentSupportedApisResponse) } +// Testing GetHomeAz API handler, return UnsupportedVerb if http method is not supported +func TestGetHomeAz_UnsupportedHttpMethod(t *testing.T) { + req, err := http.NewRequestWithContext(context.TODO(), http.MethodPost, cns.GetHomeAz, http.NoBody) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + var getHomeAzResponse cns.GetHomeAzResponse + err = decodeResponse(w, &getHomeAzResponse) + if err != nil && getHomeAzResponse.Response.ReturnCode != types.UnsupportedVerb { + t.Errorf("GetHomeAz not failing to unsupported http method with response %+v", getHomeAzResponse) + } + logger.Printf("GetHomeAz Responded with %+v\n", getHomeAzResponse) +} + func TestCreateHostNCApipaEndpoint(t *testing.T) { fmt.Println("Test: createHostNCApipaEndpoint") @@ -1359,7 +1377,7 @@ func startService() error { } nmagentClient := &fakes.NMAgentClientFake{} - service, err = NewHTTPRestService(&config, &fakes.WireserverClientFake{}, nmagentClient, nil, nil) + service, err = NewHTTPRestService(&config, &fakes.WireserverClientFake{}, nmagentClient, nil, nil, nil) if err != nil { return err } diff --git a/cns/restserver/homeazmonitor.go b/cns/restserver/homeazmonitor.go new file mode 100644 index 0000000000..81b892d793 --- /dev/null +++ b/cns/restserver/homeazmonitor.go @@ -0,0 +1,166 @@ +package restserver + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/Azure/azure-container-networking/cns" + "github.com/Azure/azure-container-networking/cns/types" + "github.com/Azure/azure-container-networking/log" + "github.com/Azure/azure-container-networking/nmagent" + "github.com/patrickmn/go-cache" + "github.com/pkg/errors" +) + +const ( + GetHomeAzAPIName = "GetHomeAz" + ContextTimeOut = 2 * time.Second + homeAzCacheKey = "HomeAz" +) + +type HomeAzMonitor struct { + nmagentClient + values *cache.Cache + // channel used as signal to end of the goroutine for populating home az cache + closing chan struct{} + cacheRefreshIntervalSecs time.Duration +} + +// NewHomeAzMonitor creates a new HomeAzMonitor object +func NewHomeAzMonitor(client nmagentClient, cacheRefreshIntervalSecs time.Duration) *HomeAzMonitor { + return &HomeAzMonitor{ + nmagentClient: client, + cacheRefreshIntervalSecs: cacheRefreshIntervalSecs, + values: cache.New(cache.NoExpiration, cache.NoExpiration), + closing: make(chan struct{}), + } +} + +// GetHomeAz returns home az cache value directly +func (h *HomeAzMonitor) GetHomeAz(_ context.Context) cns.GetHomeAzResponse { + return h.readCacheValue() +} + +// updateCacheValue updates home az cache value +func (h *HomeAzMonitor) updateCacheValue(resp cns.GetHomeAzResponse) { + h.values.Set(homeAzCacheKey, resp, cache.NoExpiration) +} + +// readCacheValue reads home az cache value +func (h *HomeAzMonitor) readCacheValue() cns.GetHomeAzResponse { + cachedResp, found := h.values.Get(homeAzCacheKey) + if !found { + return cns.GetHomeAzResponse{Response: cns.Response{ + ReturnCode: types.UnexpectedError, + Message: "HomeAz Cache is unavailable", + }, HomeAzResponse: cns.HomeAzResponse{IsSupported: false}} + } + return cachedResp.(cns.GetHomeAzResponse) +} + +// Start starts a new thread to refresh home az cache +func (h *HomeAzMonitor) Start() { + go h.refresh() +} + +// Stop ends the refresh thread +func (h *HomeAzMonitor) Stop() { + close(h.closing) +} + +// refresh periodically pulls home az from nmagent +func (h *HomeAzMonitor) refresh() { + // Ticker will not tick right away, so proactively make a call here to achieve that + ctx, cancel := context.WithTimeout(context.Background(), ContextTimeOut) + h.Populate(ctx) + cancel() + + ticker := time.NewTicker(h.cacheRefreshIntervalSecs) + defer ticker.Stop() + for { + select { + case <-h.closing: + return + case <-ticker.C: + ctx, cancel = context.WithTimeout(context.Background(), ContextTimeOut) + h.Populate(ctx) + cancel() + } + } +} + +// Populate makes call to nmagent to retrieve home az if getHomeAz api is supported by nmagent +func (h *HomeAzMonitor) Populate(ctx context.Context) { + supportedApis, err := h.SupportedAPIs(ctx) + if err != nil { + returnMessage := fmt.Sprintf("[HomeAzMonitor] failed to query nmagent's supported apis, %v", err) + returnCode := types.NmAgentSupportedApisError + h.update(returnCode, returnMessage, cns.HomeAzResponse{IsSupported: false}) + return + } + // check if getHomeAz api is supported by nmagent + if !isAPISupportedByNMAgent(supportedApis, GetHomeAzAPIName) { + returnMessage := fmt.Sprintf("[HomeAzMonitor] nmagent does not support %s api.", GetHomeAzAPIName) + returnCode := types.Success + h.update(returnCode, returnMessage, cns.HomeAzResponse{IsSupported: false}) + return + } + + // calling NMAgent to get home AZ + azResponse, err := h.nmagentClient.GetHomeAz(ctx) + if err != nil { + apiError := nmagent.Error{} + if ok := errors.As(err, &apiError); ok { + switch apiError.StatusCode() { + case http.StatusInternalServerError: + returnMessage := fmt.Sprintf("[HomeAzMonitor] nmagent internal server error, %v", err) + returnCode := types.NmAgentInternalServerError + h.update(returnCode, returnMessage, cns.HomeAzResponse{IsSupported: true}) + return + + case http.StatusUnauthorized: + returnMessage := fmt.Sprintf("[HomeAzMonitor] failed to authenticate with OwningServiceInstanceId, %v", err) + returnCode := types.StatusUnauthorized + h.update(returnCode, returnMessage, cns.HomeAzResponse{IsSupported: true}) + return + + default: + returnMessage := fmt.Sprintf("[HomeAzMonitor] failed with StatusCode: %d", apiError.StatusCode()) + returnCode := types.UnexpectedError + h.update(returnCode, returnMessage, cns.HomeAzResponse{IsSupported: true}) + return + } + } + returnMessage := fmt.Sprintf("[HomeAzMonitor] failed with Error. %v", err) + returnCode := types.UnexpectedError + h.update(returnCode, returnMessage, cns.HomeAzResponse{IsSupported: true}) + return + } + + h.update(types.Success, "Get Home Az succeeded", cns.HomeAzResponse{IsSupported: true, HomeAz: azResponse.HomeAz}) +} + +// update constructs a GetHomeAzResponse entity and update its cache +func (h *HomeAzMonitor) update(code types.ResponseCode, msg string, homeAzResponse cns.HomeAzResponse) { + log.Debugf(msg) + resp := cns.GetHomeAzResponse{ + Response: cns.Response{ + ReturnCode: code, + Message: msg, + }, + HomeAzResponse: homeAzResponse, + } + h.updateCacheValue(resp) +} + +// isAPISupportedByNMAgent checks if a nmagent client api slice contains a given api +func isAPISupportedByNMAgent(apis []string, api string) bool { + for _, supportedAPI := range apis { + if supportedAPI == api { + return true + } + } + return false +} diff --git a/cns/restserver/homeazmonitor_test.go b/cns/restserver/homeazmonitor_test.go new file mode 100644 index 0000000000..1b699c1ea8 --- /dev/null +++ b/cns/restserver/homeazmonitor_test.go @@ -0,0 +1,89 @@ +package restserver + +import ( + "context" + "testing" + "time" + + "github.com/Azure/azure-container-networking/cns" + "github.com/Azure/azure-container-networking/cns/fakes" + "github.com/Azure/azure-container-networking/cns/types" + "github.com/Azure/azure-container-networking/nmagent" + "github.com/google/go-cmp/cmp" + "github.com/pkg/errors" +) + +// TestHomeAzMonitor makes sure the HomeAzMonitor works properly in caching home az +func TestHomeAzMonitor(t *testing.T) { + tests := []struct { + name string + client *fakes.NMAgentClientFake + homeAzExp cns.HomeAzResponse + shouldErr bool + }{ + { + "happy path", + &fakes.NMAgentClientFake{ + SupportedAPIsF: func(ctx context.Context) ([]string, error) { + return []string{"GetHomeAz"}, nil + }, + GetHomeAzF: func(ctx context.Context) (nmagent.AzResponse, error) { + return nmagent.AzResponse{HomeAz: uint(1)}, nil + }, + }, + cns.HomeAzResponse{IsSupported: true, HomeAz: uint(1)}, + false, + }, + { + "getHomeAz is not supported in nmagent", + &fakes.NMAgentClientFake{ + SupportedAPIsF: func(ctx context.Context) ([]string, error) { + return []string{"dummy"}, nil + }, + GetHomeAzF: func(ctx context.Context) (nmagent.AzResponse, error) { + return nmagent.AzResponse{}, nil + }, + }, + cns.HomeAzResponse{}, + false, + }, + { + "api supported but got unexpected errors", + &fakes.NMAgentClientFake{ + SupportedAPIsF: func(ctx context.Context) ([]string, error) { + return []string{GetHomeAzAPIName}, nil + }, + GetHomeAzF: func(ctx context.Context) (nmagent.AzResponse, error) { + return nmagent.AzResponse{}, errors.New("unexpected error") + }, + }, + cns.HomeAzResponse{IsSupported: true}, + true, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + homeAzMonitor := NewHomeAzMonitor(test.client, time.Second) + homeAzMonitor.Populate(context.TODO()) + + getHomeAzResponse := homeAzMonitor.GetHomeAz(context.TODO()) + // check the homeAz cache value + if !cmp.Equal(getHomeAzResponse.HomeAzResponse, test.homeAzExp) { + t.Error("homeAz cache differs from expectation: diff:", cmp.Diff(getHomeAzResponse.HomeAzResponse, test.homeAzExp)) + } + + // check returnCode for error + if getHomeAzResponse.Response.ReturnCode != types.Success && !test.shouldErr { + t.Fatal("unexpected error: ", getHomeAzResponse.Response.Message) + } + if getHomeAzResponse.Response.ReturnCode == types.Success && test.shouldErr { + t.Fatal("expected error but received none") + } + t.Cleanup(func() { + homeAzMonitor.Stop() + }) + }) + } +} diff --git a/cns/restserver/ipam_test.go b/cns/restserver/ipam_test.go index 432d1b0618..5e927b7855 100644 --- a/cns/restserver/ipam_test.go +++ b/cns/restserver/ipam_test.go @@ -40,7 +40,7 @@ var ( func getTestService() *HTTPRestService { var config common.ServiceConfig - httpsvc, _ := NewHTTPRestService(&config, &fakes.WireserverClientFake{}, &fakes.NMAgentClientFake{}, store.NewMockStore(""), nil) + httpsvc, _ := NewHTTPRestService(&config, &fakes.WireserverClientFake{}, &fakes.NMAgentClientFake{}, store.NewMockStore(""), nil, nil) svc = httpsvc.(*HTTPRestService) svc.IPAMPoolMonitor = &fakes.MonitorFake{} setOrchestratorTypeInternal(cns.KubernetesCRD) diff --git a/cns/restserver/restserver.go b/cns/restserver/restserver.go index ac4aff9e6e..0fdcbcbaf6 100644 --- a/cns/restserver/restserver.go +++ b/cns/restserver/restserver.go @@ -46,6 +46,7 @@ type nmagentClient interface { SupportedAPIs(context.Context) ([]string, error) GetNCVersion(context.Context, nma.NCVersionRequest) (nma.NCVersion, error) GetNCVersionList(context.Context) (nma.NCVersionList, error) + GetHomeAz(context.Context) (nma.AzResponse, error) } // HTTPRestService represents http listener for CNS - Container Networking Service. @@ -55,6 +56,7 @@ type HTTPRestService struct { wscli interfaceGetter ipamClient *ipamclient.IpamClient nma nmagentClient + homeAzMonitor *HomeAzMonitor networkContainer *networkcontainers.NetworkContainers PodIPIDByPodInterfaceKey map[string]string // PodInterfaceId is key and value is Pod IP (SecondaryIP) uuid. PodIPConfigState map[string]cns.IPConfigurationStatus // Secondary IP ID(uuid) is key @@ -146,7 +148,7 @@ type networkInfo struct { // NewHTTPRestService creates a new HTTP Service object. func NewHTTPRestService(config *common.ServiceConfig, wscli interfaceGetter, nmagentClient nmagentClient, - endpointStateStore store.KeyValueStore, gen CNIConflistGenerator, + endpointStateStore store.KeyValueStore, gen CNIConflistGenerator, homeAzMonitor *HomeAzMonitor, ) (cns.HTTPService, error) { service, err := cns.NewService(config.Name, config.Version, config.ChannelMode, config.Store) if err != nil { @@ -202,6 +204,7 @@ func NewHTTPRestService(config *common.ServiceConfig, wscli interfaceGetter, nma podsPendingIPAssignment: bounded.NewTimedSet(250), // nolint:gomnd // maxpods EndpointStateStore: endpointStateStore, EndpointState: make(map[string]*EndpointInfo), + homeAzMonitor: homeAzMonitor, cniConflistGenerator: gen, }, nil } @@ -253,6 +256,7 @@ func (service *HTTPRestService) Init(config *common.ServiceConfig) error { listener.AddHandler(cns.PathDebugPodContext, service.handleDebugPodContext) listener.AddHandler(cns.PathDebugRestData, service.handleDebugRestData) listener.AddHandler(cns.NetworkContainersURLPath, service.getOrRefreshNetworkContainers) + listener.AddHandler(cns.GetHomeAz, service.getHomeAz) // handlers for v0.2 listener.AddHandler(cns.V2Prefix+cns.SetEnvironmentPath, service.setEnvironment) @@ -276,6 +280,7 @@ func (service *HTTPRestService) Init(config *common.ServiceConfig) error { listener.AddHandler(cns.V2Prefix+cns.CreateHostNCApipaEndpointPath, service.createHostNCApipaEndpoint) listener.AddHandler(cns.V2Prefix+cns.DeleteHostNCApipaEndpointPath, service.deleteHostNCApipaEndpoint) listener.AddHandler(cns.V2Prefix+cns.NmAgentSupportedApisPath, service.nmAgentSupportedApisHandler) + listener.AddHandler(cns.V2Prefix+cns.GetHomeAz, service.getHomeAz) // Initialize HTTP client to be reused in CNS connectionTimeout, _ := service.GetOption(acn.OptHttpConnectionTimeout).(int) diff --git a/cns/restserver/restserver_test.go b/cns/restserver/restserver_test.go index 0a575a7c2d..c79f35db76 100644 --- a/cns/restserver/restserver_test.go +++ b/cns/restserver/restserver_test.go @@ -1,6 +1,8 @@ package restserver -import "github.com/Azure/azure-container-networking/cns/fakes" +import ( + "github.com/Azure/azure-container-networking/cns/fakes" +) func setMockNMAgent(h *HTTPRestService, m *fakes.NMAgentClientFake) func() { // this is a hack that exists because the tests are too DRY, so the setup diff --git a/cns/restserver/util.go b/cns/restserver/util.go index cdf3df82f6..e1e7f34f23 100644 --- a/cns/restserver/util.go +++ b/cns/restserver/util.go @@ -923,3 +923,9 @@ func (service *HTTPRestService) createNetworkContainers(createNetworkContainerRe Message: "", } } + +// setResponse encodes the http response +func (service *HTTPRestService) setResponse(w http.ResponseWriter, returnCode types.ResponseCode, response interface{}) { + serviceErr := service.Listener.Encode(w, &response) + logger.Response(service.Name, response, returnCode, serviceErr) +} diff --git a/cns/service/main.go b/cns/service/main.go index 2101e2e574..645b32ad47 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -522,6 +522,10 @@ func main() { return } + homeAzMonitor := restserver.NewHomeAzMonitor(nmaClient, time.Duration(cnsconfig.PopulateHomeAzCacheRetryIntervalSecs)*time.Second) + logger.Printf("start the goroutine for refreshing homeAz") + homeAzMonitor.Start() + if cnsconfig.ChannelMode == cns.Managed { config.ChannelMode = cns.Managed privateEndpoint = cnsconfig.ManagedSettings.PrivateEndpoint @@ -609,7 +613,7 @@ func main() { // Create CNS object. httpRestService, err := restserver.NewHTTPRestService(&config, &wireserver.Client{HTTPClient: &http.Client{}}, nmaClient, - endpointStateStore, conflistGenerator) + endpointStateStore, conflistGenerator, homeAzMonitor) if err != nil { logger.Errorf("Failed to create CNS object, err:%v.\n", err) return @@ -835,6 +839,9 @@ func main() { } } + logger.Printf("end the goroutine for refreshing homeAz") + homeAzMonitor.Stop() + logger.Printf("stop cns service") // Cleanup. if httpRestService != nil { diff --git a/cns/types/codes.go b/cns/types/codes.go index b61e1b3ba5..fe5eb57b5c 100644 --- a/cns/types/codes.go +++ b/cns/types/codes.go @@ -40,6 +40,8 @@ const ( UnsupportedNCVersion ResponseCode = 38 FailedToRunIPTableCmd ResponseCode = 39 NilEndpointStateStore ResponseCode = 40 + NmAgentInternalServerError ResponseCode = 41 + StatusUnauthorized ResponseCode = 42 UnexpectedError ResponseCode = 99 ) @@ -116,6 +118,10 @@ func (c ResponseCode) String() string { return "UnsupportedOrchestratorType" case UnsupportedVerb: return "UnsupportedVerb" + case NmAgentInternalServerError: + return "NmAgentInternalServerError" + case StatusUnauthorized: + return "StatusUnauthorized" default: return "UnknownError" } diff --git a/go.mod b/go.mod index 73b5211fb8..a7b8c0b5a4 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/nxadm/tail v1.4.8 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.18.1 + github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.12.2 github.com/prometheus/client_model v0.2.0 diff --git a/go.sum b/go.sum index 936b2512e5..f9177365dd 100644 --- a/go.sum +++ b/go.sum @@ -684,6 +684,8 @@ github.com/opencontainers/selinux v1.8.0/go.mod h1:RScLhm78qiWa2gbVCcGkC7tCGdgk3 github.com/opencontainers/selinux v1.8.2/go.mod h1:MUIHuUEvKB1wtJjQdOyYRgOnLD2xAPP8dBsCoU0KuF8= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= diff --git a/nmagent/client.go b/nmagent/client.go index 346f8724c8..f7f43f80ce 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -235,6 +235,33 @@ func (c *Client) GetNCVersionList(ctx context.Context) (NCVersionList, error) { return out, nil } +// GetHomeAz gets node's home az from nmagent +func (c *Client) GetHomeAz(ctx context.Context) (AzResponse, error) { + getHomeAzRequest := &GetHomeAzRequest{} + var homeAzResponse AzResponse + req, err := c.buildRequest(ctx, getHomeAzRequest) + if err != nil { + return homeAzResponse, errors.Wrap(err, "building request") + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return homeAzResponse, errors.Wrap(err, "submitting request") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return homeAzResponse, die(resp.StatusCode, resp.Header, resp.Body) + } + + err = json.NewDecoder(resp.Body).Decode(&homeAzResponse) + if err != nil { + return homeAzResponse, errors.Wrap(err, "decoding response") + } + + return homeAzResponse, nil +} + func die(code int, headers http.Header, body io.ReadCloser) error { // nolint:errcheck // make a best effort to return whatever information we can // returning an error here without the code and source would diff --git a/nmagent/client_test.go b/nmagent/client_test.go index e8f7c2fd19..b055b5b1f1 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -622,3 +622,71 @@ func TestGetNCVersionList(t *testing.T) { }) } } + +func TestGetHomeAz(t *testing.T) { + tests := []struct { + name string + exp nmagent.AzResponse + expPath string + resp map[string]interface{} + shouldErr bool + }{ + { + "happy path", + nmagent.AzResponse{HomeAz: uint(1)}, + "/machine/plugins/?comp=nmagent&type=GetHomeAz", + map[string]interface{}{ + "httpStatusCode": "200", + "HomeAz": 1, + }, + false, + }, + { + "empty response", + nmagent.AzResponse{}, + "/machine/plugins/?comp=nmagent&type=GetHomeAz", + map[string]interface{}{ + "httpStatusCode": "500", + }, + true, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var gotPath string + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + gotPath = req.URL.Path + rr := httptest.NewRecorder() + err := json.NewEncoder(rr).Encode(test.resp) + if err != nil { + t.Fatal("unexpected error encoding response: err:", err) + } + rr.WriteHeader(http.StatusOK) + return rr.Result(), nil + }, + }) + + got, err := client.GetHomeAz(context.TODO()) + if err != nil && !test.shouldErr { + t.Fatal("unexpected error: err:", err) + } + + if err == nil && test.shouldErr { + t.Fatal("expected error but received none") + } + + if gotPath != test.expPath { + t.Error("paths differ: got:", gotPath, "exp:", test.expPath) + } + + if !cmp.Equal(got, test.exp) { + t.Error("response differs from expectation: diff:", cmp.Diff(got, test.exp)) + } + }) + } +} diff --git a/nmagent/requests.go b/nmagent/requests.go index a7809e0b14..a24846053b 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -387,3 +387,29 @@ func (NCVersionListRequest) Validate() error { // cannot be made invalid it's fine for this to simply... return nil } + +var _ Request = &GetHomeAzRequest{} + +type GetHomeAzRequest struct{} + +// Body is a no-op method to satisfy the Request interface while indicating +// that there is no body for a GetHomeAz Request. +func (g *GetHomeAzRequest) Body() (io.Reader, error) { + return nil, nil +} + +// Method indicates that GetHomeAz requests are GET requests. +func (g *GetHomeAzRequest) Method() string { + return http.MethodGet +} + +// Path returns the necessary URI path for invoking a GetHomeAz request. +func (g *GetHomeAzRequest) Path() string { + return "/GetHomeAz" +} + +// Validate is a no-op method because GetHomeAzRequest have no parameters, +// and therefore can never be invalid. +func (g *GetHomeAzRequest) Validate() error { + return nil +} diff --git a/nmagent/responses.go b/nmagent/responses.go index 9dc8b3ac1d..e5324d59f9 100644 --- a/nmagent/responses.go +++ b/nmagent/responses.go @@ -36,3 +36,7 @@ type NCVersion struct { type NCVersionList struct { Containers []NCVersion `json:"networkContainers"` } + +type AzResponse struct { + HomeAz uint `json:"homeAz"` +} diff --git a/vendor/github.com/patrickmn/go-cache/CONTRIBUTORS b/vendor/github.com/patrickmn/go-cache/CONTRIBUTORS new file mode 100644 index 0000000000..2b16e99741 --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/CONTRIBUTORS @@ -0,0 +1,9 @@ +This is a list of people who have contributed code to go-cache. They, or their +employers, are the copyright holders of the contributed code. Contributed code +is subject to the license restrictions listed in LICENSE (as they were when the +code was contributed.) + +Dustin Sallings +Jason Mooberry +Sergey Shepelev +Alex Edwards diff --git a/vendor/github.com/patrickmn/go-cache/LICENSE b/vendor/github.com/patrickmn/go-cache/LICENSE new file mode 100644 index 0000000000..db9903c75c --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2012-2017 Patrick Mylund Nielsen and the go-cache contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/github.com/patrickmn/go-cache/README.md b/vendor/github.com/patrickmn/go-cache/README.md new file mode 100644 index 0000000000..c5789cc66c --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/README.md @@ -0,0 +1,83 @@ +# go-cache + +go-cache is an in-memory key:value store/cache similar to memcached that is +suitable for applications running on a single machine. Its major advantage is +that, being essentially a thread-safe `map[string]interface{}` with expiration +times, it doesn't need to serialize or transmit its contents over the network. + +Any object can be stored, for a given duration or forever, and the cache can be +safely used by multiple goroutines. + +Although go-cache isn't meant to be used as a persistent datastore, the entire +cache can be saved to and loaded from a file (using `c.Items()` to retrieve the +items map to serialize, and `NewFrom()` to create a cache from a deserialized +one) to recover from downtime quickly. (See the docs for `NewFrom()` for caveats.) + +### Installation + +`go get github.com/patrickmn/go-cache` + +### Usage + +```go +import ( + "fmt" + "github.com/patrickmn/go-cache" + "time" +) + +func main() { + // Create a cache with a default expiration time of 5 minutes, and which + // purges expired items every 10 minutes + c := cache.New(5*time.Minute, 10*time.Minute) + + // Set the value of the key "foo" to "bar", with the default expiration time + c.Set("foo", "bar", cache.DefaultExpiration) + + // Set the value of the key "baz" to 42, with no expiration time + // (the item won't be removed until it is re-set, or removed using + // c.Delete("baz") + c.Set("baz", 42, cache.NoExpiration) + + // Get the string associated with the key "foo" from the cache + foo, found := c.Get("foo") + if found { + fmt.Println(foo) + } + + // Since Go is statically typed, and cache values can be anything, type + // assertion is needed when values are being passed to functions that don't + // take arbitrary types, (i.e. interface{}). The simplest way to do this for + // values which will only be used once--e.g. for passing to another + // function--is: + foo, found := c.Get("foo") + if found { + MyFunction(foo.(string)) + } + + // This gets tedious if the value is used several times in the same function. + // You might do either of the following instead: + if x, found := c.Get("foo"); found { + foo := x.(string) + // ... + } + // or + var foo string + if x, found := c.Get("foo"); found { + foo = x.(string) + } + // ... + // foo can then be passed around freely as a string + + // Want performance? Store pointers! + c.Set("foo", &MyStruct, cache.DefaultExpiration) + if x, found := c.Get("foo"); found { + foo := x.(*MyStruct) + // ... + } +} +``` + +### Reference + +`godoc` or [http://godoc.org/github.com/patrickmn/go-cache](http://godoc.org/github.com/patrickmn/go-cache) diff --git a/vendor/github.com/patrickmn/go-cache/cache.go b/vendor/github.com/patrickmn/go-cache/cache.go new file mode 100644 index 0000000000..db88d2f2cb --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/cache.go @@ -0,0 +1,1161 @@ +package cache + +import ( + "encoding/gob" + "fmt" + "io" + "os" + "runtime" + "sync" + "time" +) + +type Item struct { + Object interface{} + Expiration int64 +} + +// Returns true if the item has expired. +func (item Item) Expired() bool { + if item.Expiration == 0 { + return false + } + return time.Now().UnixNano() > item.Expiration +} + +const ( + // For use with functions that take an expiration time. + NoExpiration time.Duration = -1 + // For use with functions that take an expiration time. Equivalent to + // passing in the same expiration duration as was given to New() or + // NewFrom() when the cache was created (e.g. 5 minutes.) + DefaultExpiration time.Duration = 0 +) + +type Cache struct { + *cache + // If this is confusing, see the comment at the bottom of New() +} + +type cache struct { + defaultExpiration time.Duration + items map[string]Item + mu sync.RWMutex + onEvicted func(string, interface{}) + janitor *janitor +} + +// Add an item to the cache, replacing any existing item. If the duration is 0 +// (DefaultExpiration), the cache's default expiration time is used. If it is -1 +// (NoExpiration), the item never expires. +func (c *cache) Set(k string, x interface{}, d time.Duration) { + // "Inlining" of set + var e int64 + if d == DefaultExpiration { + d = c.defaultExpiration + } + if d > 0 { + e = time.Now().Add(d).UnixNano() + } + c.mu.Lock() + c.items[k] = Item{ + Object: x, + Expiration: e, + } + // TODO: Calls to mu.Unlock are currently not deferred because defer + // adds ~200 ns (as of go1.) + c.mu.Unlock() +} + +func (c *cache) set(k string, x interface{}, d time.Duration) { + var e int64 + if d == DefaultExpiration { + d = c.defaultExpiration + } + if d > 0 { + e = time.Now().Add(d).UnixNano() + } + c.items[k] = Item{ + Object: x, + Expiration: e, + } +} + +// Add an item to the cache, replacing any existing item, using the default +// expiration. +func (c *cache) SetDefault(k string, x interface{}) { + c.Set(k, x, DefaultExpiration) +} + +// Add an item to the cache only if an item doesn't already exist for the given +// key, or if the existing item has expired. Returns an error otherwise. +func (c *cache) Add(k string, x interface{}, d time.Duration) error { + c.mu.Lock() + _, found := c.get(k) + if found { + c.mu.Unlock() + return fmt.Errorf("Item %s already exists", k) + } + c.set(k, x, d) + c.mu.Unlock() + return nil +} + +// Set a new value for the cache key only if it already exists, and the existing +// item hasn't expired. Returns an error otherwise. +func (c *cache) Replace(k string, x interface{}, d time.Duration) error { + c.mu.Lock() + _, found := c.get(k) + if !found { + c.mu.Unlock() + return fmt.Errorf("Item %s doesn't exist", k) + } + c.set(k, x, d) + c.mu.Unlock() + return nil +} + +// Get an item from the cache. Returns the item or nil, and a bool indicating +// whether the key was found. +func (c *cache) Get(k string) (interface{}, bool) { + c.mu.RLock() + // "Inlining" of get and Expired + item, found := c.items[k] + if !found { + c.mu.RUnlock() + return nil, false + } + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + c.mu.RUnlock() + return nil, false + } + } + c.mu.RUnlock() + return item.Object, true +} + +// GetWithExpiration returns an item and its expiration time from the cache. +// It returns the item or nil, the expiration time if one is set (if the item +// never expires a zero value for time.Time is returned), and a bool indicating +// whether the key was found. +func (c *cache) GetWithExpiration(k string) (interface{}, time.Time, bool) { + c.mu.RLock() + // "Inlining" of get and Expired + item, found := c.items[k] + if !found { + c.mu.RUnlock() + return nil, time.Time{}, false + } + + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + c.mu.RUnlock() + return nil, time.Time{}, false + } + + // Return the item and the expiration time + c.mu.RUnlock() + return item.Object, time.Unix(0, item.Expiration), true + } + + // If expiration <= 0 (i.e. no expiration time set) then return the item + // and a zeroed time.Time + c.mu.RUnlock() + return item.Object, time.Time{}, true +} + +func (c *cache) get(k string) (interface{}, bool) { + item, found := c.items[k] + if !found { + return nil, false + } + // "Inlining" of Expired + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + return nil, false + } + } + return item.Object, true +} + +// Increment an item of type int, int8, int16, int32, int64, uintptr, uint, +// uint8, uint32, or uint64, float32 or float64 by n. Returns an error if the +// item's value is not an integer, if it was not found, or if it is not +// possible to increment it by n. To retrieve the incremented value, use one +// of the specialized methods, e.g. IncrementInt64. +func (c *cache) Increment(k string, n int64) error { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item %s not found", k) + } + switch v.Object.(type) { + case int: + v.Object = v.Object.(int) + int(n) + case int8: + v.Object = v.Object.(int8) + int8(n) + case int16: + v.Object = v.Object.(int16) + int16(n) + case int32: + v.Object = v.Object.(int32) + int32(n) + case int64: + v.Object = v.Object.(int64) + n + case uint: + v.Object = v.Object.(uint) + uint(n) + case uintptr: + v.Object = v.Object.(uintptr) + uintptr(n) + case uint8: + v.Object = v.Object.(uint8) + uint8(n) + case uint16: + v.Object = v.Object.(uint16) + uint16(n) + case uint32: + v.Object = v.Object.(uint32) + uint32(n) + case uint64: + v.Object = v.Object.(uint64) + uint64(n) + case float32: + v.Object = v.Object.(float32) + float32(n) + case float64: + v.Object = v.Object.(float64) + float64(n) + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s is not an integer", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Increment an item of type float32 or float64 by n. Returns an error if the +// item's value is not floating point, if it was not found, or if it is not +// possible to increment it by n. Pass a negative number to decrement the +// value. To retrieve the incremented value, use one of the specialized methods, +// e.g. IncrementFloat64. +func (c *cache) IncrementFloat(k string, n float64) error { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item %s not found", k) + } + switch v.Object.(type) { + case float32: + v.Object = v.Object.(float32) + float32(n) + case float64: + v.Object = v.Object.(float64) + n + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s does not have type float32 or float64", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Increment an item of type int by n. Returns an error if the item's value is +// not an int, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt(k string, n int) (int, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int8 by n. Returns an error if the item's value is +// not an int8, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt8(k string, n int8) (int8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int8", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int16 by n. Returns an error if the item's value is +// not an int16, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt16(k string, n int16) (int16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int16", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int32 by n. Returns an error if the item's value is +// not an int32, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt32(k string, n int32) (int32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int32", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int64 by n. Returns an error if the item's value is +// not an int64, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt64(k string, n int64) (int64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int64", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint by n. Returns an error if the item's value is +// not an uint, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementUint(k string, n uint) (uint, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uintptr by n. Returns an error if the item's value +// is not an uintptr, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUintptr(k string, n uintptr) (uintptr, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uintptr) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uintptr", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint8 by n. Returns an error if the item's value +// is not an uint8, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint8(k string, n uint8) (uint8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint8", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint16 by n. Returns an error if the item's value +// is not an uint16, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint16(k string, n uint16) (uint16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint16", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint32 by n. Returns an error if the item's value +// is not an uint32, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint32(k string, n uint32) (uint32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint32", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint64 by n. Returns an error if the item's value +// is not an uint64, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint64(k string, n uint64) (uint64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint64", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type float32 by n. Returns an error if the item's value +// is not an float32, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementFloat32(k string, n float32) (float32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float32", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type float64 by n. Returns an error if the item's value +// is not an float64, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementFloat64(k string, n float64) (float64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float64", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int, int8, int16, int32, int64, uintptr, uint, +// uint8, uint32, or uint64, float32 or float64 by n. Returns an error if the +// item's value is not an integer, if it was not found, or if it is not +// possible to decrement it by n. To retrieve the decremented value, use one +// of the specialized methods, e.g. DecrementInt64. +func (c *cache) Decrement(k string, n int64) error { + // TODO: Implement Increment and Decrement more cleanly. + // (Cannot do Increment(k, n*-1) for uints.) + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item not found") + } + switch v.Object.(type) { + case int: + v.Object = v.Object.(int) - int(n) + case int8: + v.Object = v.Object.(int8) - int8(n) + case int16: + v.Object = v.Object.(int16) - int16(n) + case int32: + v.Object = v.Object.(int32) - int32(n) + case int64: + v.Object = v.Object.(int64) - n + case uint: + v.Object = v.Object.(uint) - uint(n) + case uintptr: + v.Object = v.Object.(uintptr) - uintptr(n) + case uint8: + v.Object = v.Object.(uint8) - uint8(n) + case uint16: + v.Object = v.Object.(uint16) - uint16(n) + case uint32: + v.Object = v.Object.(uint32) - uint32(n) + case uint64: + v.Object = v.Object.(uint64) - uint64(n) + case float32: + v.Object = v.Object.(float32) - float32(n) + case float64: + v.Object = v.Object.(float64) - float64(n) + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s is not an integer", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Decrement an item of type float32 or float64 by n. Returns an error if the +// item's value is not floating point, if it was not found, or if it is not +// possible to decrement it by n. Pass a negative number to decrement the +// value. To retrieve the decremented value, use one of the specialized methods, +// e.g. DecrementFloat64. +func (c *cache) DecrementFloat(k string, n float64) error { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item %s not found", k) + } + switch v.Object.(type) { + case float32: + v.Object = v.Object.(float32) - float32(n) + case float64: + v.Object = v.Object.(float64) - n + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s does not have type float32 or float64", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Decrement an item of type int by n. Returns an error if the item's value is +// not an int, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt(k string, n int) (int, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int8 by n. Returns an error if the item's value is +// not an int8, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt8(k string, n int8) (int8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int8", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int16 by n. Returns an error if the item's value is +// not an int16, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt16(k string, n int16) (int16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int16", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int32 by n. Returns an error if the item's value is +// not an int32, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt32(k string, n int32) (int32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int32", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int64 by n. Returns an error if the item's value is +// not an int64, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt64(k string, n int64) (int64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int64", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint by n. Returns an error if the item's value is +// not an uint, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementUint(k string, n uint) (uint, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uintptr by n. Returns an error if the item's value +// is not an uintptr, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUintptr(k string, n uintptr) (uintptr, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uintptr) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uintptr", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint8 by n. Returns an error if the item's value is +// not an uint8, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementUint8(k string, n uint8) (uint8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint8", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint16 by n. Returns an error if the item's value +// is not an uint16, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUint16(k string, n uint16) (uint16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint16", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint32 by n. Returns an error if the item's value +// is not an uint32, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUint32(k string, n uint32) (uint32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint32", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint64 by n. Returns an error if the item's value +// is not an uint64, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUint64(k string, n uint64) (uint64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint64", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type float32 by n. Returns an error if the item's value +// is not an float32, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementFloat32(k string, n float32) (float32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float32", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type float64 by n. Returns an error if the item's value +// is not an float64, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementFloat64(k string, n float64) (float64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float64", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Delete an item from the cache. Does nothing if the key is not in the cache. +func (c *cache) Delete(k string) { + c.mu.Lock() + v, evicted := c.delete(k) + c.mu.Unlock() + if evicted { + c.onEvicted(k, v) + } +} + +func (c *cache) delete(k string) (interface{}, bool) { + if c.onEvicted != nil { + if v, found := c.items[k]; found { + delete(c.items, k) + return v.Object, true + } + } + delete(c.items, k) + return nil, false +} + +type keyAndValue struct { + key string + value interface{} +} + +// Delete all expired items from the cache. +func (c *cache) DeleteExpired() { + var evictedItems []keyAndValue + now := time.Now().UnixNano() + c.mu.Lock() + for k, v := range c.items { + // "Inlining" of expired + if v.Expiration > 0 && now > v.Expiration { + ov, evicted := c.delete(k) + if evicted { + evictedItems = append(evictedItems, keyAndValue{k, ov}) + } + } + } + c.mu.Unlock() + for _, v := range evictedItems { + c.onEvicted(v.key, v.value) + } +} + +// Sets an (optional) function that is called with the key and value when an +// item is evicted from the cache. (Including when it is deleted manually, but +// not when it is overwritten.) Set to nil to disable. +func (c *cache) OnEvicted(f func(string, interface{})) { + c.mu.Lock() + c.onEvicted = f + c.mu.Unlock() +} + +// Write the cache's items (using Gob) to an io.Writer. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) Save(w io.Writer) (err error) { + enc := gob.NewEncoder(w) + defer func() { + if x := recover(); x != nil { + err = fmt.Errorf("Error registering item types with Gob library") + } + }() + c.mu.RLock() + defer c.mu.RUnlock() + for _, v := range c.items { + gob.Register(v.Object) + } + err = enc.Encode(&c.items) + return +} + +// Save the cache's items to the given filename, creating the file if it +// doesn't exist, and overwriting it if it does. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) SaveFile(fname string) error { + fp, err := os.Create(fname) + if err != nil { + return err + } + err = c.Save(fp) + if err != nil { + fp.Close() + return err + } + return fp.Close() +} + +// Add (Gob-serialized) cache items from an io.Reader, excluding any items with +// keys that already exist (and haven't expired) in the current cache. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) Load(r io.Reader) error { + dec := gob.NewDecoder(r) + items := map[string]Item{} + err := dec.Decode(&items) + if err == nil { + c.mu.Lock() + defer c.mu.Unlock() + for k, v := range items { + ov, found := c.items[k] + if !found || ov.Expired() { + c.items[k] = v + } + } + } + return err +} + +// Load and add cache items from the given filename, excluding any items with +// keys that already exist in the current cache. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) LoadFile(fname string) error { + fp, err := os.Open(fname) + if err != nil { + return err + } + err = c.Load(fp) + if err != nil { + fp.Close() + return err + } + return fp.Close() +} + +// Copies all unexpired items in the cache into a new map and returns it. +func (c *cache) Items() map[string]Item { + c.mu.RLock() + defer c.mu.RUnlock() + m := make(map[string]Item, len(c.items)) + now := time.Now().UnixNano() + for k, v := range c.items { + // "Inlining" of Expired + if v.Expiration > 0 { + if now > v.Expiration { + continue + } + } + m[k] = v + } + return m +} + +// Returns the number of items in the cache. This may include items that have +// expired, but have not yet been cleaned up. +func (c *cache) ItemCount() int { + c.mu.RLock() + n := len(c.items) + c.mu.RUnlock() + return n +} + +// Delete all items from the cache. +func (c *cache) Flush() { + c.mu.Lock() + c.items = map[string]Item{} + c.mu.Unlock() +} + +type janitor struct { + Interval time.Duration + stop chan bool +} + +func (j *janitor) Run(c *cache) { + ticker := time.NewTicker(j.Interval) + for { + select { + case <-ticker.C: + c.DeleteExpired() + case <-j.stop: + ticker.Stop() + return + } + } +} + +func stopJanitor(c *Cache) { + c.janitor.stop <- true +} + +func runJanitor(c *cache, ci time.Duration) { + j := &janitor{ + Interval: ci, + stop: make(chan bool), + } + c.janitor = j + go j.Run(c) +} + +func newCache(de time.Duration, m map[string]Item) *cache { + if de == 0 { + de = -1 + } + c := &cache{ + defaultExpiration: de, + items: m, + } + return c +} + +func newCacheWithJanitor(de time.Duration, ci time.Duration, m map[string]Item) *Cache { + c := newCache(de, m) + // This trick ensures that the janitor goroutine (which--granted it + // was enabled--is running DeleteExpired on c forever) does not keep + // the returned C object from being garbage collected. When it is + // garbage collected, the finalizer stops the janitor goroutine, after + // which c can be collected. + C := &Cache{c} + if ci > 0 { + runJanitor(c, ci) + runtime.SetFinalizer(C, stopJanitor) + } + return C +} + +// Return a new cache with a given default expiration duration and cleanup +// interval. If the expiration duration is less than one (or NoExpiration), +// the items in the cache never expire (by default), and must be deleted +// manually. If the cleanup interval is less than one, expired items are not +// deleted from the cache before calling c.DeleteExpired(). +func New(defaultExpiration, cleanupInterval time.Duration) *Cache { + items := make(map[string]Item) + return newCacheWithJanitor(defaultExpiration, cleanupInterval, items) +} + +// Return a new cache with a given default expiration duration and cleanup +// interval. If the expiration duration is less than one (or NoExpiration), +// the items in the cache never expire (by default), and must be deleted +// manually. If the cleanup interval is less than one, expired items are not +// deleted from the cache before calling c.DeleteExpired(). +// +// NewFrom() also accepts an items map which will serve as the underlying map +// for the cache. This is useful for starting from a deserialized cache +// (serialized using e.g. gob.Encode() on c.Items()), or passing in e.g. +// make(map[string]Item, 500) to improve startup performance when the cache +// is expected to reach a certain minimum size. +// +// Only the cache's methods synchronize access to this map, so it is not +// recommended to keep any references to the map around after creating a cache. +// If need be, the map can be accessed at a later point using c.Items() (subject +// to the same caveat.) +// +// Note regarding serialization: When using e.g. gob, make sure to +// gob.Register() the individual types stored in the cache before encoding a +// map retrieved with c.Items(), and to register those same types before +// decoding a blob containing an items map. +func NewFrom(defaultExpiration, cleanupInterval time.Duration, items map[string]Item) *Cache { + return newCacheWithJanitor(defaultExpiration, cleanupInterval, items) +} diff --git a/vendor/github.com/patrickmn/go-cache/sharded.go b/vendor/github.com/patrickmn/go-cache/sharded.go new file mode 100644 index 0000000000..bcc0538bcc --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/sharded.go @@ -0,0 +1,192 @@ +package cache + +import ( + "crypto/rand" + "math" + "math/big" + insecurerand "math/rand" + "os" + "runtime" + "time" +) + +// This is an experimental and unexported (for now) attempt at making a cache +// with better algorithmic complexity than the standard one, namely by +// preventing write locks of the entire cache when an item is added. As of the +// time of writing, the overhead of selecting buckets results in cache +// operations being about twice as slow as for the standard cache with small +// total cache sizes, and faster for larger ones. +// +// See cache_test.go for a few benchmarks. + +type unexportedShardedCache struct { + *shardedCache +} + +type shardedCache struct { + seed uint32 + m uint32 + cs []*cache + janitor *shardedJanitor +} + +// djb2 with better shuffling. 5x faster than FNV with the hash.Hash overhead. +func djb33(seed uint32, k string) uint32 { + var ( + l = uint32(len(k)) + d = 5381 + seed + l + i = uint32(0) + ) + // Why is all this 5x faster than a for loop? + if l >= 4 { + for i < l-4 { + d = (d * 33) ^ uint32(k[i]) + d = (d * 33) ^ uint32(k[i+1]) + d = (d * 33) ^ uint32(k[i+2]) + d = (d * 33) ^ uint32(k[i+3]) + i += 4 + } + } + switch l - i { + case 1: + case 2: + d = (d * 33) ^ uint32(k[i]) + case 3: + d = (d * 33) ^ uint32(k[i]) + d = (d * 33) ^ uint32(k[i+1]) + case 4: + d = (d * 33) ^ uint32(k[i]) + d = (d * 33) ^ uint32(k[i+1]) + d = (d * 33) ^ uint32(k[i+2]) + } + return d ^ (d >> 16) +} + +func (sc *shardedCache) bucket(k string) *cache { + return sc.cs[djb33(sc.seed, k)%sc.m] +} + +func (sc *shardedCache) Set(k string, x interface{}, d time.Duration) { + sc.bucket(k).Set(k, x, d) +} + +func (sc *shardedCache) Add(k string, x interface{}, d time.Duration) error { + return sc.bucket(k).Add(k, x, d) +} + +func (sc *shardedCache) Replace(k string, x interface{}, d time.Duration) error { + return sc.bucket(k).Replace(k, x, d) +} + +func (sc *shardedCache) Get(k string) (interface{}, bool) { + return sc.bucket(k).Get(k) +} + +func (sc *shardedCache) Increment(k string, n int64) error { + return sc.bucket(k).Increment(k, n) +} + +func (sc *shardedCache) IncrementFloat(k string, n float64) error { + return sc.bucket(k).IncrementFloat(k, n) +} + +func (sc *shardedCache) Decrement(k string, n int64) error { + return sc.bucket(k).Decrement(k, n) +} + +func (sc *shardedCache) Delete(k string) { + sc.bucket(k).Delete(k) +} + +func (sc *shardedCache) DeleteExpired() { + for _, v := range sc.cs { + v.DeleteExpired() + } +} + +// Returns the items in the cache. This may include items that have expired, +// but have not yet been cleaned up. If this is significant, the Expiration +// fields of the items should be checked. Note that explicit synchronization +// is needed to use a cache and its corresponding Items() return values at +// the same time, as the maps are shared. +func (sc *shardedCache) Items() []map[string]Item { + res := make([]map[string]Item, len(sc.cs)) + for i, v := range sc.cs { + res[i] = v.Items() + } + return res +} + +func (sc *shardedCache) Flush() { + for _, v := range sc.cs { + v.Flush() + } +} + +type shardedJanitor struct { + Interval time.Duration + stop chan bool +} + +func (j *shardedJanitor) Run(sc *shardedCache) { + j.stop = make(chan bool) + tick := time.Tick(j.Interval) + for { + select { + case <-tick: + sc.DeleteExpired() + case <-j.stop: + return + } + } +} + +func stopShardedJanitor(sc *unexportedShardedCache) { + sc.janitor.stop <- true +} + +func runShardedJanitor(sc *shardedCache, ci time.Duration) { + j := &shardedJanitor{ + Interval: ci, + } + sc.janitor = j + go j.Run(sc) +} + +func newShardedCache(n int, de time.Duration) *shardedCache { + max := big.NewInt(0).SetUint64(uint64(math.MaxUint32)) + rnd, err := rand.Int(rand.Reader, max) + var seed uint32 + if err != nil { + os.Stderr.Write([]byte("WARNING: go-cache's newShardedCache failed to read from the system CSPRNG (/dev/urandom or equivalent.) Your system's security may be compromised. Continuing with an insecure seed.\n")) + seed = insecurerand.Uint32() + } else { + seed = uint32(rnd.Uint64()) + } + sc := &shardedCache{ + seed: seed, + m: uint32(n), + cs: make([]*cache, n), + } + for i := 0; i < n; i++ { + c := &cache{ + defaultExpiration: de, + items: map[string]Item{}, + } + sc.cs[i] = c + } + return sc +} + +func unexportedNewSharded(defaultExpiration, cleanupInterval time.Duration, shards int) *unexportedShardedCache { + if defaultExpiration == 0 { + defaultExpiration = -1 + } + sc := newShardedCache(shards, defaultExpiration) + SC := &unexportedShardedCache{sc} + if cleanupInterval > 0 { + runShardedJanitor(sc, cleanupInterval) + runtime.SetFinalizer(SC, stopShardedJanitor) + } + return SC +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 1030a7aed2..2ae4a9fc81 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -346,6 +346,9 @@ github.com/onsi/gomega/matchers/support/goraph/edge github.com/onsi/gomega/matchers/support/goraph/node github.com/onsi/gomega/matchers/support/goraph/util github.com/onsi/gomega/types +# github.com/patrickmn/go-cache v2.1.0+incompatible +## explicit +github.com/patrickmn/go-cache # github.com/pelletier/go-toml v1.9.5 ## explicit; go 1.12 github.com/pelletier/go-toml