diff --git a/gateway/api.go b/gateway/api.go index 9ba49860d59..aa2d18cd051 100644 --- a/gateway/api.go +++ b/gateway/api.go @@ -892,7 +892,7 @@ func (gw *Gateway) handleRemoveSortedSetRange(keyName, scoreFrom, scoreTo string } func (gw *Gateway) handleGetPolicy(polID string) (interface{}, int) { - if pol := gw.getPolicy(polID); pol.ID != "" { + if pol, ok := gw.PolicyByID(polID); ok && pol.ID != "" { return pol, http.StatusOK } diff --git a/gateway/gateway.go b/gateway/gateway.go new file mode 100644 index 00000000000..36b65e969c8 --- /dev/null +++ b/gateway/gateway.go @@ -0,0 +1,38 @@ +package gateway + +import ( + "github.com/TykTechnologies/tyk/internal/policy" + "github.com/TykTechnologies/tyk/user" +) + +type Repository interface { + policy.Repository +} + +var _ Repository = &Gateway{} + +func (gw *Gateway) PolicyIDs() []string { + gw.policiesMu.RLock() + defer gw.policiesMu.RUnlock() + + result := make([]string, 0, len(gw.policiesByID)) + for id := range gw.policiesByID { + result = append(result, id) + } + return result +} + +func (gw *Gateway) PolicyByID(polID string) (user.Policy, bool) { + gw.policiesMu.RLock() + defer gw.policiesMu.RUnlock() + + pol, ok := gw.policiesByID[polID] + return pol, ok +} + +func (gw *Gateway) PolicyCount() int { + gw.policiesMu.RLock() + defer gw.policiesMu.RUnlock() + + return len(gw.policiesByID) +} diff --git a/gateway/middleware.go b/gateway/middleware.go index 92b21dcae96..91e84d960bc 100644 --- a/gateway/middleware.go +++ b/gateway/middleware.go @@ -15,6 +15,7 @@ import ( "github.com/TykTechnologies/tyk/internal/cache" "github.com/TykTechnologies/tyk/internal/event" "github.com/TykTechnologies/tyk/internal/otel" + "github.com/TykTechnologies/tyk/internal/policy" "github.com/TykTechnologies/tyk/rpc" "github.com/TykTechnologies/tyk/header" @@ -346,420 +347,15 @@ func (t *BaseMiddleware) UpdateRequestSession(r *http.Request) bool { return true } -// clearSession clears the quota, rate limit and complexity values so that partitioned policies can apply their values. -// Otherwise, if the session has already a higher value, an applied policy will not win, and its values will be ignored. -func (t *BaseMiddleware) clearSession(session *user.SessionState) { - policies := session.PolicyIDs() - for _, polID := range policies { - t.Gw.policiesMu.RLock() - policy, ok := t.Gw.policiesByID[polID] - t.Gw.policiesMu.RUnlock() - if !ok { - continue - } - - all := !(policy.Partitions.Quota || policy.Partitions.RateLimit || policy.Partitions.Acl || policy.Partitions.Complexity) - - if policy.Partitions.Quota || all { - session.QuotaMax = 0 - session.QuotaRemaining = 0 - } - - if policy.Partitions.RateLimit || all { - session.Rate = 0 - session.Per = 0 - session.ThrottleRetryLimit = 0 - session.ThrottleInterval = 0 - } - - if policy.Partitions.Complexity || all { - session.MaxQueryDepth = 0 - } - } -} - // ApplyPolicies will check if any policies are loaded. If any are, it // will overwrite the session state to use the policy values. func (t *BaseMiddleware) ApplyPolicies(session *user.SessionState) error { - rights := make(map[string]user.AccessDefinition) - tags := make(map[string]bool) - if session.MetaData == nil { - session.MetaData = make(map[string]interface{}) - } - - t.clearSession(session) - - didQuota, didRateLimit, didACL, didComplexity := make(map[string]bool), make(map[string]bool), make(map[string]bool), make(map[string]bool) - - var ( - err error - lookupMap map[string]user.Policy - policyIDs []string - ) - - customPolicies, err := session.CustomPolicies() - if err != nil { - policyIDs = session.PolicyIDs() - t.Gw.policiesMu.RLock() - lookupMap = t.Gw.policiesByID - defer t.Gw.policiesMu.RUnlock() - } else { - lookupMap = customPolicies - policyIDs = make([]string, 0, len(customPolicies)) - for _, val := range customPolicies { - policyIDs = append(policyIDs, val.ID) - } - } - - for _, polID := range policyIDs { - policy, ok := lookupMap[polID] - if !ok { - err := fmt.Errorf("policy not found: %q", polID) - t.Logger().Error(err) - if len(policyIDs) > 1 { - continue - } - - return err - } - // Check ownership, policy org owner must be the same as API, - // otherwise you could overwrite a session key with a policy from a different org! - if t.Spec != nil && policy.OrgID != t.Spec.OrgID { - err := fmt.Errorf("attempting to apply policy from different organisation to key, skipping") - t.Logger().Error(err) - return err - } - - if policy.Partitions.PerAPI && - (policy.Partitions.Quota || policy.Partitions.RateLimit || policy.Partitions.Acl || policy.Partitions.Complexity) { - err := fmt.Errorf("cannot apply policy %s which has per_api and any of partitions set", policy.ID) - log.Error(err) - return err - } - - if policy.Partitions.PerAPI { - for apiID, accessRights := range policy.AccessRights { - // new logic when you can specify quota or rate in more than one policy but for different APIs - if didQuota[apiID] || didRateLimit[apiID] || didACL[apiID] || didComplexity[apiID] { // no other partitions allowed - err := fmt.Errorf("cannot apply multiple policies when some have per_api set and some are partitioned") - log.Error(err) - return err - } - - idForScope := apiID - // check if we don't have limit on API level specified when policy was created - if accessRights.Limit.IsEmpty() { - // limit was not specified on API level so we will populate it from policy - idForScope = policy.ID - accessRights.Limit = policy.APILimit() - } - accessRights.AllowanceScope = idForScope - accessRights.Limit.SetBy = idForScope - - // respect current quota renews (on API limit level) - if r, ok := session.AccessRights[apiID]; ok && !r.Limit.IsEmpty() { - accessRights.Limit.QuotaRenews = r.Limit.QuotaRenews - } - - if r, ok := session.AccessRights[apiID]; ok { - // If GQL introspection is disabled, keep that configuration. - if r.DisableIntrospection { - accessRights.DisableIntrospection = r.DisableIntrospection - } - } - - // overwrite session access right for this API - rights[apiID] = accessRights - - // identify that limit for that API is set (to allow set it only once) - didACL[apiID] = true - didQuota[apiID] = true - didRateLimit[apiID] = true - didComplexity[apiID] = true - } - } else { - usePartitions := policy.Partitions.Quota || policy.Partitions.RateLimit || policy.Partitions.Acl || policy.Partitions.Complexity - - for k, v := range policy.AccessRights { - ar := v - - if !usePartitions || policy.Partitions.Acl { - didACL[k] = true - - ar.AllowedURLs = copyAllowedURLs(v.AllowedURLs) - - // Merge ACLs for the same API - if r, ok := rights[k]; ok { - // If GQL introspection is disabled, keep that configuration. - if v.DisableIntrospection { - r.DisableIntrospection = v.DisableIntrospection - } - r.Versions = appendIfMissing(rights[k].Versions, v.Versions...) - - for _, u := range v.AllowedURLs { - found := false - for ai, au := range r.AllowedURLs { - if u.URL == au.URL { - found = true - r.AllowedURLs[ai].Methods = appendIfMissing(au.Methods, u.Methods...) - } - } - - if !found { - r.AllowedURLs = append(r.AllowedURLs, v.AllowedURLs...) - } - } - - for _, t := range v.RestrictedTypes { - for ri, rt := range r.RestrictedTypes { - if t.Name == rt.Name { - r.RestrictedTypes[ri].Fields = intersection(rt.Fields, t.Fields) - } - } - } - - for _, t := range v.AllowedTypes { - for ri, rt := range r.AllowedTypes { - if t.Name == rt.Name { - r.AllowedTypes[ri].Fields = intersection(rt.Fields, t.Fields) - } - } - } - - mergeFieldLimits := func(res *user.FieldLimits, new user.FieldLimits) { - if greaterThanInt(new.MaxQueryDepth, res.MaxQueryDepth) { - res.MaxQueryDepth = new.MaxQueryDepth - } - } - - for _, far := range v.FieldAccessRights { - exists := false - for i, rfar := range r.FieldAccessRights { - if far.TypeName == rfar.TypeName && far.FieldName == rfar.FieldName { - exists = true - mergeFieldLimits(&r.FieldAccessRights[i].Limits, far.Limits) - } - } - - if !exists { - r.FieldAccessRights = append(r.FieldAccessRights, far) - } - } - - ar = r - } - - ar.Limit.SetBy = policy.ID - } - - if !usePartitions || policy.Partitions.Quota { - didQuota[k] = true - if greaterThanInt64(policy.QuotaMax, ar.Limit.QuotaMax) { - - ar.Limit.QuotaMax = policy.QuotaMax - if greaterThanInt64(policy.QuotaMax, session.QuotaMax) { - session.QuotaMax = policy.QuotaMax - } - } - - if policy.QuotaRenewalRate > ar.Limit.QuotaRenewalRate { - ar.Limit.QuotaRenewalRate = policy.QuotaRenewalRate - if policy.QuotaRenewalRate > session.QuotaRenewalRate { - session.QuotaRenewalRate = policy.QuotaRenewalRate - } - } - } - - if !usePartitions || policy.Partitions.RateLimit { - didRateLimit[k] = true - - apiLimits := ar.Limit - policyLimits := policy.APILimit() - sessionLimits := session.APILimit() - - // Update Rate, Per and Smoothing - if apiLimits.Less(policyLimits) { - ar.Limit.Rate = policyLimits.Rate - ar.Limit.Per = policyLimits.Per - ar.Limit.Smoothing = policyLimits.Smoothing - - if sessionLimits.Less(policyLimits) { - session.Rate = policyLimits.Rate - session.Per = policyLimits.Per - session.Smoothing = policyLimits.Smoothing - } - } - - if policy.ThrottleRetryLimit > ar.Limit.ThrottleRetryLimit { - ar.Limit.ThrottleRetryLimit = policy.ThrottleRetryLimit - if policy.ThrottleRetryLimit > session.ThrottleRetryLimit { - session.ThrottleRetryLimit = policy.ThrottleRetryLimit - } - } - - if policy.ThrottleInterval > ar.Limit.ThrottleInterval { - ar.Limit.ThrottleInterval = policy.ThrottleInterval - if policy.ThrottleInterval > session.ThrottleInterval { - session.ThrottleInterval = policy.ThrottleInterval - } - } - } - - if !usePartitions || policy.Partitions.Complexity { - didComplexity[k] = true - - if greaterThanInt(policy.MaxQueryDepth, ar.Limit.MaxQueryDepth) { - ar.Limit.MaxQueryDepth = policy.MaxQueryDepth - if greaterThanInt(policy.MaxQueryDepth, session.MaxQueryDepth) { - session.MaxQueryDepth = policy.MaxQueryDepth - } - } - } - - // Respect existing QuotaRenews - if r, ok := session.AccessRights[k]; ok && !r.Limit.IsEmpty() { - ar.Limit.QuotaRenews = r.Limit.QuotaRenews - } - - rights[k] = ar - } - - // Master policy case - if len(policy.AccessRights) == 0 { - if !usePartitions || policy.Partitions.RateLimit { - session.Rate = policy.Rate - session.Per = policy.Per - session.Smoothing = policy.Smoothing - session.ThrottleInterval = policy.ThrottleInterval - session.ThrottleRetryLimit = policy.ThrottleRetryLimit - } - - if !usePartitions || policy.Partitions.Complexity { - session.MaxQueryDepth = policy.MaxQueryDepth - } - - if !usePartitions || policy.Partitions.Quota { - session.QuotaMax = policy.QuotaMax - session.QuotaRenewalRate = policy.QuotaRenewalRate - } - } - - if !session.HMACEnabled { - session.HMACEnabled = policy.HMACEnabled - } - - if !session.EnableHTTPSignatureValidation { - session.EnableHTTPSignatureValidation = policy.EnableHTTPSignatureValidation - } - } - - session.IsInactive = session.IsInactive || policy.IsInactive - - for _, tag := range policy.Tags { - tags[tag] = true - } - - for k, v := range policy.MetaData { - session.MetaData[k] = v - } - - if policy.LastUpdated > session.LastUpdated { - session.LastUpdated = policy.LastUpdated - } - } - - for _, tag := range session.Tags { - tags[tag] = true - } - - // set tags - session.Tags = []string{} - for tag := range tags { - session.Tags = appendIfMissing(session.Tags, tag) - } - - if len(policyIDs) == 0 { - for apiID, accessRight := range session.AccessRights { - // check if the api in the session has per api limit - if !accessRight.Limit.IsEmpty() { - accessRight.AllowanceScope = apiID - session.AccessRights[apiID] = accessRight - } - } - } - - distinctACL := make(map[string]bool) - - for _, v := range rights { - if v.Limit.SetBy != "" { - distinctACL[v.Limit.SetBy] = true - } - } - - // If some APIs had only ACL partitions, inherit rest from session level - for k, v := range rights { - if !didACL[k] { - delete(rights, k) - continue - } - - if !didRateLimit[k] { - v.Limit.Rate = session.Rate - v.Limit.Per = session.Per - v.Limit.Smoothing = session.Smoothing - v.Limit.ThrottleInterval = session.ThrottleInterval - v.Limit.ThrottleRetryLimit = session.ThrottleRetryLimit - } - - if !didComplexity[k] { - v.Limit.MaxQueryDepth = session.MaxQueryDepth - } - - if !didQuota[k] { - v.Limit.QuotaMax = session.QuotaMax - v.Limit.QuotaRenewalRate = session.QuotaRenewalRate - v.Limit.QuotaRenews = session.QuotaRenews - } - - // If multime ACL - if len(distinctACL) > 1 { - if v.AllowanceScope == "" && v.Limit.SetBy != "" { - v.AllowanceScope = v.Limit.SetBy - } - } - - v.Limit.SetBy = "" - - rights[k] = v - } - - // If we have policies defining rules for one single API, update session root vars (legacy) - if len(didQuota) == 1 && len(didRateLimit) == 1 && len(didComplexity) == 1 { - for _, v := range rights { - if len(didRateLimit) == 1 { - session.Rate = v.Limit.Rate - session.Per = v.Limit.Per - session.Smoothing = v.Limit.Smoothing - } - - if len(didQuota) == 1 { - session.QuotaMax = v.Limit.QuotaMax - session.QuotaRenews = v.Limit.QuotaRenews - session.QuotaRenewalRate = v.Limit.QuotaRenewalRate - } - - if len(didComplexity) == 1 { - session.MaxQueryDepth = v.Limit.MaxQueryDepth - } - } + var orgID *string + if t.Spec != nil { + orgID = &t.Spec.OrgID } - - // Override session ACL if at least one policy define it - if len(didACL) > 0 { - session.AccessRights = rights - } - - return nil + store := policy.New(orgID, t.Gw, log) + return store.Apply(session) } func copyAllowedURLs(input []user.AccessSpec) []user.AccessSpec { @@ -926,7 +522,7 @@ func (t *BaseMiddleware) handleRateLimitFailure(r *http.Request, e event.Event, // Report in health check reportHealthValue(t.Spec, Throttle, "-1") - return errors.New(event.String(e)), http.StatusTooManyRequests + return errors.New(message), http.StatusTooManyRequests } func (t *BaseMiddleware) getAuthType() string { diff --git a/gateway/mw_rate_limiting.go b/gateway/mw_rate_limiting.go index 1b70f16270e..dd9ab045540 100644 --- a/gateway/mw_rate_limiting.go +++ b/gateway/mw_rate_limiting.go @@ -96,7 +96,7 @@ func (k *RateLimitAndQuotaCheck) ProcessRequest(w http.ResponseWriter, r *http.R switch reason { case sessionFailNone: case sessionFailRateLimit: - err, errCode := k.handleRateLimitFailure(r, event.RateLimitExceeded, "", rateLimitKey) + err, errCode := k.handleRateLimitFailure(r, event.RateLimitExceeded, "Rate Limit Exceeded", rateLimitKey) if throttleRetryLimit > 0 { for { ctxIncThrottleLevel(r, throttleRetryLimit) diff --git a/gateway/policy_test.go b/gateway/policy_test.go index 9d5316f0398..73b016cd881 100644 --- a/gateway/policy_test.go +++ b/gateway/policy_test.go @@ -773,9 +773,12 @@ func (s *Test) TestPrepareApplyPolicies() (*BaseMiddleware, []testApplyPoliciesD }, } + gotPolicy, ok := s.Gw.PolicyByID("per-path2") + + assert.True(t, ok) assert.Equal(t, user.AccessSpec{ URL: "/user", Methods: []string{"GET"}, - }, s.Gw.getPolicy("per-path2").AccessRights["a"].AllowedURLs[0]) + }, gotPolicy.AccessRights["a"].AllowedURLs[0]) assert.Equal(t, want, sess.AccessRights) }, @@ -1193,7 +1196,6 @@ func TestApplyMultiPolicies(t *testing.T) { ts := StartTest(nil) defer ts.Close() - ts.Gw.policiesMu.RLock() policy1 := user.Policy{ ID: "policy1", Rate: 1000, @@ -1208,6 +1210,8 @@ func TestApplyMultiPolicies(t *testing.T) { }, } + assert.True(t, !policy1.APILimit().IsEmpty()) + policy2 := user.Policy{ ID: "policy2", Rate: 100, @@ -1225,11 +1229,14 @@ func TestApplyMultiPolicies(t *testing.T) { }, } + assert.True(t, !policy2.APILimit().IsEmpty()) + + ts.Gw.policiesMu.Lock() ts.Gw.policiesByID = map[string]user.Policy{ "policy1": policy1, "policy2": policy2, } - ts.Gw.policiesMu.RUnlock() + ts.Gw.policiesMu.Unlock() // load APIs ts.Gw.BuildAndLoadAPI( @@ -1265,7 +1272,13 @@ func TestApplyMultiPolicies(t *testing.T) { // create key key := uuid.New() ts.Run(t, []test.TestCase{ - {Method: http.MethodPost, Path: "/tyk/keys/" + key, Data: session, AdminAuth: true, Code: 200}, + { + Method: http.MethodPost, + Path: "/tyk/keys/" + key, + Data: session, + AdminAuth: true, + Code: 200, + }, }...) // run requests to different APIs diff --git a/gateway/rpc_storage_handler.go b/gateway/rpc_storage_handler.go index c39e4bed2e5..496646ea9ce 100644 --- a/gateway/rpc_storage_handler.go +++ b/gateway/rpc_storage_handler.go @@ -171,7 +171,7 @@ func (r *RPCStorageHandler) buildNodeInfo() []byte { Health: r.Gw.getHealthCheckInfo(), Stats: apidef.GWStats{ APIsCount: r.Gw.apisByIDLen(), - PoliciesCount: r.Gw.policiesByIDLen(), + PoliciesCount: r.Gw.PolicyCount(), }, HostDetails: im.HostDetails{ Hostname: r.Gw.hostDetails.Hostname, diff --git a/gateway/server.go b/gateway/server.go index 849f23e2ecf..d3908aea2eb 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -309,19 +309,6 @@ func (gw *Gateway) getAPIDefinition(apiID string) (*apidef.APIDefinition, error) return apiSpec.APIDefinition, nil } -func (gw *Gateway) getPolicy(polID string) user.Policy { - gw.policiesMu.RLock() - pol := gw.policiesByID[polID] - gw.policiesMu.RUnlock() - return pol -} - -func (gw *Gateway) policiesByIDLen() int { - gw.policiesMu.RLock() - defer gw.policiesMu.RUnlock() - return len(gw.policiesByID) -} - func (gw *Gateway) apisByIDLen() int { gw.apisMu.RLock() defer gw.apisMu.RUnlock() diff --git a/gateway/server_test.go b/gateway/server_test.go index 6b4636b844f..152fabe5ce6 100644 --- a/gateway/server_test.go +++ b/gateway/server_test.go @@ -153,7 +153,7 @@ func TestGateway_policiesByIDLen(t *testing.T) { }) } - actual := ts.Gw.policiesByIDLen() + actual := ts.Gw.PolicyCount() assert.Equal(t, tc.expected, actual) }) diff --git a/internal/event/event.go b/internal/event/event.go index 6bdd42a2a61..e0ce929af69 100644 --- a/internal/event/event.go +++ b/internal/event/event.go @@ -57,7 +57,6 @@ const ( // eventMap contains a map of events to a readable title for the event. // The title value should not contain ending punctuation. var eventMap = map[Event]string{ - RateLimitExceeded: "Key Rate Limit Exceeded", RateLimitSmoothingUp: "Rate limit increased with smoothing", RateLimitSmoothingDown: "Rate limit decreased with smoothing", } diff --git a/internal/event/event_test.go b/internal/event/event_test.go index 16a59f79a61..1a42655b8d5 100644 --- a/internal/event/event_test.go +++ b/internal/event/event_test.go @@ -14,7 +14,7 @@ func TestEventToString(t *testing.T) { t.Run("Event with description", func(t *testing.T) { t.Parallel() - s := String(RateLimitExceeded) + s := String(RateLimitSmoothingUp) assert.NotEmpty(t, s) assert.Contains(t, s, " ") }) diff --git a/internal/policy/Taskfile.yml b/internal/policy/Taskfile.yml new file mode 100644 index 00000000000..8f5cf09cff2 --- /dev/null +++ b/internal/policy/Taskfile.yml @@ -0,0 +1,40 @@ +--- +version: "3" + +includes: + services: + taskfile: ../../docker/services/Taskfile.yml + dir: ../../docker/services + +vars: + run: . + +tasks: + default: + desc: "Run tests" + deps: [ services:up ] + requires: + vars: [run] + cmds: + - defer: { task: services:down } + - goimports -w . + - go fmt ./... + - go test -count=1 -run='({{.run}})' -cover -coverprofile=pkg.cov -v . + + cover: + desc: "Show source coverage" + aliases: [coverage, cov] + cmds: + - go tool cover -func=pkg.cov + + uncover: + desc: "Show uncovered source" + cmds: + - uncover pkg.cov + + install:uncover: + desc: "Install uncover" + env: + GOBIN: /usr/local/bin + cmds: + - go install github.com/gregoryv/uncover/...@latest diff --git a/internal/policy/apply.go b/internal/policy/apply.go new file mode 100644 index 00000000000..192be005726 --- /dev/null +++ b/internal/policy/apply.go @@ -0,0 +1,478 @@ +package policy + +import ( + "errors" + "fmt" + + "github.com/sirupsen/logrus" + + "github.com/TykTechnologies/tyk/user" +) + +// Repository is a storage encapsulating policy retrieval. +// Gateway implements this object to decouple this package. +type Repository interface { + PolicyCount() int + PolicyIDs() []string + PolicyByID(string) (user.Policy, bool) +} + +type Service struct { + storage Repository + logger *logrus.Logger + + // used for validation if not empty + orgID *string +} + +func New(orgID *string, storage Repository, logger *logrus.Logger) *Service { + return &Service{ + orgID: orgID, + storage: storage, + logger: logger, + } +} + +// ClearSession clears the quota, rate limit and complexity values so that partitioned policies can apply their values. +// Otherwise, if the session has already a higher value, an applied policy will not win, and its values will be ignored. +func (t *Service) ClearSession(session *user.SessionState) error { + policies := session.PolicyIDs() + + for _, polID := range policies { + policy, ok := t.storage.PolicyByID(polID) + if !ok { + return fmt.Errorf("policy not found: %s", polID) + } + + all := !(policy.Partitions.Quota || policy.Partitions.RateLimit || policy.Partitions.Acl || policy.Partitions.Complexity) + + if policy.Partitions.Quota || all { + session.QuotaMax = 0 + session.QuotaRemaining = 0 + } + + if policy.Partitions.RateLimit || all { + session.Rate = 0 + session.Per = 0 + session.Smoothing = nil + session.ThrottleRetryLimit = 0 + session.ThrottleInterval = 0 + } + + if policy.Partitions.Complexity || all { + session.MaxQueryDepth = 0 + } + } + + return nil +} + +// ApplyPolicies will check if any policies are loaded. If any are, it +// will overwrite the session state to use the policy values. +func (t *Service) Apply(session *user.SessionState) error { + rights := make(map[string]user.AccessDefinition) + tags := make(map[string]bool) + if session.MetaData == nil { + session.MetaData = make(map[string]interface{}) + } + + if err := t.ClearSession(session); err != nil { + t.logger.WithError(err).Warn("error clearing session") + } + + didQuota, didRateLimit, didACL, didComplexity := make(map[string]bool), make(map[string]bool), make(map[string]bool), make(map[string]bool) + + var ( + err error + policyIDs []string + ) + + storage := t.storage + customPolicies, err := session.CustomPolicies() + if err != nil { + policyIDs = session.PolicyIDs() + } else { + storage = NewStore(customPolicies) + policyIDs = storage.PolicyIDs() + } + + for _, polID := range policyIDs { + policy, ok := storage.PolicyByID(polID) + if !ok { + err := fmt.Errorf("policy not found: %q", polID) + t.Logger().Error(err) + if len(policyIDs) > 1 { + continue + } + + return err + } + // Check ownership, policy org owner must be the same as API, + // otherwise you could overwrite a session key with a policy from a different org! + if t.orgID != nil && policy.OrgID != *t.orgID { + err := errors.New("attempting to apply policy from different organisation to key, skipping") + t.Logger().Error(err) + return err + } + + if policy.Partitions.PerAPI && + (policy.Partitions.Quota || policy.Partitions.RateLimit || policy.Partitions.Acl || policy.Partitions.Complexity) { + err := fmt.Errorf("cannot apply policy %s which has per_api and any of partitions set", policy.ID) + t.logger.Error(err) + return err + } + + if policy.Partitions.PerAPI { + for apiID, accessRights := range policy.AccessRights { + // new logic when you can specify quota or rate in more than one policy but for different APIs + if didQuota[apiID] || didRateLimit[apiID] || didACL[apiID] || didComplexity[apiID] { // no other partitions allowed + err := fmt.Errorf("cannot apply multiple policies when some have per_api set and some are partitioned") + t.logger.Error(err) + return err + } + + idForScope := apiID + // check if we don't have limit on API level specified when policy was created + if accessRights.Limit.IsEmpty() { + // limit was not specified on API level so we will populate it from policy + idForScope = policy.ID + accessRights.Limit = policy.APILimit() + } + accessRights.AllowanceScope = idForScope + accessRights.Limit.SetBy = idForScope + + // respect current quota renews (on API limit level) + if r, ok := session.AccessRights[apiID]; ok && !r.Limit.IsEmpty() { + accessRights.Limit.QuotaRenews = r.Limit.QuotaRenews + } + + if r, ok := session.AccessRights[apiID]; ok { + // If GQL introspection is disabled, keep that configuration. + if r.DisableIntrospection { + accessRights.DisableIntrospection = r.DisableIntrospection + } + } + + // overwrite session access right for this API + rights[apiID] = accessRights + + // identify that limit for that API is set (to allow set it only once) + didACL[apiID] = true + didQuota[apiID] = true + didRateLimit[apiID] = true + didComplexity[apiID] = true + } + } else { + usePartitions := policy.Partitions.Quota || policy.Partitions.RateLimit || policy.Partitions.Acl || policy.Partitions.Complexity + + for k, v := range policy.AccessRights { + ar := v + + if !usePartitions || policy.Partitions.Acl { + didACL[k] = true + + ar.AllowedURLs = copyAllowedURLs(v.AllowedURLs) + + // Merge ACLs for the same API + if r, ok := rights[k]; ok { + // If GQL introspection is disabled, keep that configuration. + if v.DisableIntrospection { + r.DisableIntrospection = v.DisableIntrospection + } + r.Versions = appendIfMissing(rights[k].Versions, v.Versions...) + + for _, u := range v.AllowedURLs { + found := false + for ai, au := range r.AllowedURLs { + if u.URL == au.URL { + found = true + r.AllowedURLs[ai].Methods = appendIfMissing(au.Methods, u.Methods...) + } + } + + if !found { + r.AllowedURLs = append(r.AllowedURLs, v.AllowedURLs...) + } + } + + for _, t := range v.RestrictedTypes { + for ri, rt := range r.RestrictedTypes { + if t.Name == rt.Name { + r.RestrictedTypes[ri].Fields = intersection(rt.Fields, t.Fields) + } + } + } + + for _, t := range v.AllowedTypes { + for ri, rt := range r.AllowedTypes { + if t.Name == rt.Name { + r.AllowedTypes[ri].Fields = intersection(rt.Fields, t.Fields) + } + } + } + + mergeFieldLimits := func(res *user.FieldLimits, new user.FieldLimits) { + if greaterThanInt(new.MaxQueryDepth, res.MaxQueryDepth) { + res.MaxQueryDepth = new.MaxQueryDepth + } + } + + for _, far := range v.FieldAccessRights { + exists := false + for i, rfar := range r.FieldAccessRights { + if far.TypeName == rfar.TypeName && far.FieldName == rfar.FieldName { + exists = true + mergeFieldLimits(&r.FieldAccessRights[i].Limits, far.Limits) + } + } + + if !exists { + r.FieldAccessRights = append(r.FieldAccessRights, far) + } + } + + ar = r + } + + ar.Limit.SetBy = policy.ID + } + + if !usePartitions || policy.Partitions.Quota { + didQuota[k] = true + if greaterThanInt64(policy.QuotaMax, ar.Limit.QuotaMax) { + + ar.Limit.QuotaMax = policy.QuotaMax + if greaterThanInt64(policy.QuotaMax, session.QuotaMax) { + session.QuotaMax = policy.QuotaMax + } + } + + if policy.QuotaRenewalRate > ar.Limit.QuotaRenewalRate { + ar.Limit.QuotaRenewalRate = policy.QuotaRenewalRate + if policy.QuotaRenewalRate > session.QuotaRenewalRate { + session.QuotaRenewalRate = policy.QuotaRenewalRate + } + } + } + + if !usePartitions || policy.Partitions.RateLimit { + didRateLimit[k] = true + + t.ApplyRateLimits(session, policy, &ar.Limit) + + if policy.ThrottleRetryLimit > ar.Limit.ThrottleRetryLimit { + ar.Limit.ThrottleRetryLimit = policy.ThrottleRetryLimit + if policy.ThrottleRetryLimit > session.ThrottleRetryLimit { + session.ThrottleRetryLimit = policy.ThrottleRetryLimit + } + } + + if policy.ThrottleInterval > ar.Limit.ThrottleInterval { + ar.Limit.ThrottleInterval = policy.ThrottleInterval + if policy.ThrottleInterval > session.ThrottleInterval { + session.ThrottleInterval = policy.ThrottleInterval + } + } + } + + if !usePartitions || policy.Partitions.Complexity { + didComplexity[k] = true + + if greaterThanInt(policy.MaxQueryDepth, ar.Limit.MaxQueryDepth) { + ar.Limit.MaxQueryDepth = policy.MaxQueryDepth + if greaterThanInt(policy.MaxQueryDepth, session.MaxQueryDepth) { + session.MaxQueryDepth = policy.MaxQueryDepth + } + } + } + + // Respect existing QuotaRenews + if r, ok := session.AccessRights[k]; ok && !r.Limit.IsEmpty() { + ar.Limit.QuotaRenews = r.Limit.QuotaRenews + } + + rights[k] = ar + } + + // Master policy case + if len(policy.AccessRights) == 0 { + if !usePartitions || policy.Partitions.RateLimit { + session.Rate = policy.Rate + session.Per = policy.Per + session.Smoothing = policy.Smoothing + session.ThrottleInterval = policy.ThrottleInterval + session.ThrottleRetryLimit = policy.ThrottleRetryLimit + } + + if !usePartitions || policy.Partitions.Complexity { + session.MaxQueryDepth = policy.MaxQueryDepth + } + + if !usePartitions || policy.Partitions.Quota { + session.QuotaMax = policy.QuotaMax + session.QuotaRenewalRate = policy.QuotaRenewalRate + } + } + + if !session.HMACEnabled { + session.HMACEnabled = policy.HMACEnabled + } + + if !session.EnableHTTPSignatureValidation { + session.EnableHTTPSignatureValidation = policy.EnableHTTPSignatureValidation + } + } + + session.IsInactive = session.IsInactive || policy.IsInactive + + for _, tag := range policy.Tags { + tags[tag] = true + } + + for k, v := range policy.MetaData { + session.MetaData[k] = v + } + + if policy.LastUpdated > session.LastUpdated { + session.LastUpdated = policy.LastUpdated + } + } + + for _, tag := range session.Tags { + tags[tag] = true + } + + // set tags + session.Tags = []string{} + for tag := range tags { + session.Tags = appendIfMissing(session.Tags, tag) + } + + if len(policyIDs) == 0 { + for apiID, accessRight := range session.AccessRights { + // check if the api in the session has per api limit + if !accessRight.Limit.IsEmpty() { + accessRight.AllowanceScope = apiID + session.AccessRights[apiID] = accessRight + } + } + } + + distinctACL := make(map[string]bool) + + for _, v := range rights { + if v.Limit.SetBy != "" { + distinctACL[v.Limit.SetBy] = true + } + } + + // If some APIs had only ACL partitions, inherit rest from session level + for k, v := range rights { + if !didACL[k] { + delete(rights, k) + continue + } + + if !didRateLimit[k] { + v.Limit.Rate = session.Rate + v.Limit.Per = session.Per + v.Limit.Smoothing = session.Smoothing + v.Limit.ThrottleInterval = session.ThrottleInterval + v.Limit.ThrottleRetryLimit = session.ThrottleRetryLimit + } + + if !didComplexity[k] { + v.Limit.MaxQueryDepth = session.MaxQueryDepth + } + + if !didQuota[k] { + v.Limit.QuotaMax = session.QuotaMax + v.Limit.QuotaRenewalRate = session.QuotaRenewalRate + v.Limit.QuotaRenews = session.QuotaRenews + } + + // If multime ACL + if len(distinctACL) > 1 { + if v.AllowanceScope == "" && v.Limit.SetBy != "" { + v.AllowanceScope = v.Limit.SetBy + } + } + + v.Limit.SetBy = "" + + rights[k] = v + } + + // If we have policies defining rules for one single API, update session root vars (legacy) + if len(didQuota) == 1 && len(didRateLimit) == 1 && len(didComplexity) == 1 { + for _, v := range rights { + if len(didRateLimit) == 1 { + session.Rate = v.Limit.Rate + session.Per = v.Limit.Per + session.Smoothing = v.Limit.Smoothing + } + + if len(didQuota) == 1 { + session.QuotaMax = v.Limit.QuotaMax + session.QuotaRenews = v.Limit.QuotaRenews + session.QuotaRenewalRate = v.Limit.QuotaRenewalRate + } + + if len(didComplexity) == 1 { + session.MaxQueryDepth = v.Limit.MaxQueryDepth + } + } + } + + // Override session ACL if at least one policy define it + if len(didACL) > 0 { + session.AccessRights = rights + } + + return nil +} + +func (t *Service) Logger() *logrus.Logger { + return t.logger +} + +// ApplyRateLimits will write policy limits to session and apiLimits. +// The limits get written if either are empty. +// The limits get written if filled and policyLimits allows a higher request rate. +func (t *Service) ApplyRateLimits(session *user.SessionState, policy user.Policy, apiLimits *user.APILimit) { + policyLimits := policy.APILimit() + if t.emptyRateLimit(policyLimits) { + return + } + + // duration is time between requests, e.g.: + // + // apiLimits: 500ms for 2 requests / second + // policyLimits: 100ms for 10 requests / second + // + // if apiLimits > policyLimits (500ms > 100ms) then + // we apply the higher rate from the policy. + // + // the policy-defined rate limits are enforced as + // a minimum possible api rate limit setting, + // raising apiLimits. + + if t.emptyRateLimit(*apiLimits) || apiLimits.Duration() > policyLimits.Duration() { + apiLimits.Rate = policyLimits.Rate + apiLimits.Per = policyLimits.Per + apiLimits.Smoothing = policyLimits.Smoothing + } + + // sessionLimits, similar to apiLimits, get policy + // rate applied if the policy allows more requests. + sessionLimits := session.APILimit() + if t.emptyRateLimit(sessionLimits) || sessionLimits.Duration() > policyLimits.Duration() { + session.Rate = policyLimits.Rate + session.Per = policyLimits.Per + session.Smoothing = policyLimits.Smoothing + } +} + +func (t *Service) emptyRateLimit(m user.APILimit) bool { + return m.Rate == 0 || m.Per == 0 +} diff --git a/internal/policy/apply_test.go b/internal/policy/apply_test.go new file mode 100644 index 00000000000..1e98c2cc32e --- /dev/null +++ b/internal/policy/apply_test.go @@ -0,0 +1,124 @@ +package policy_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/TykTechnologies/tyk/internal/policy" + "github.com/TykTechnologies/tyk/user" +) + +func TestApplyRateLimits_PolicyLimits(t *testing.T) { + svc := &policy.Service{} + + t.Run("policy limits unset", func(t *testing.T) { + session := &user.SessionState{ + Rate: 5, + Per: 10, + } + apiLimits := user.APILimit{ + Rate: 10, + Per: 10, + } + policy := user.Policy{} + + svc.ApplyRateLimits(session, policy, &apiLimits) + + assert.Equal(t, 10, int(apiLimits.Rate)) + assert.Equal(t, 5, int(session.Rate)) + }) + + t.Run("policy limits apply all", func(t *testing.T) { + session := &user.SessionState{ + Rate: 5, + Per: 10, + } + apiLimits := user.APILimit{ + Rate: 5, + Per: 10, + } + policy := user.Policy{ + Rate: 10, + Per: 10, + } + + svc.ApplyRateLimits(session, policy, &apiLimits) + + assert.Equal(t, 10, int(apiLimits.Rate)) + assert.Equal(t, 10, int(session.Rate)) + }) + + // As the policy defined a higher rate than apiLimits, + // changes are applied to api limits, but skipped on + // the session as the session has a higher allowance. + t.Run("policy limits apply per-api", func(t *testing.T) { + session := &user.SessionState{ + Rate: 15, + Per: 10, + } + apiLimits := user.APILimit{ + Rate: 5, + Per: 10, + } + policy := user.Policy{ + Rate: 10, + Per: 10, + } + + svc.ApplyRateLimits(session, policy, &apiLimits) + + assert.Equal(t, 10, int(apiLimits.Rate)) + assert.Equal(t, 15, int(session.Rate)) + }) + + // As the policy defined a lower rate than apiLimits, + // no changes to api limits are applied. + t.Run("policy limits skip", func(t *testing.T) { + session := &user.SessionState{ + Rate: 5, + Per: 10, + } + apiLimits := user.APILimit{ + Rate: 15, + Per: 10, + } + policy := user.Policy{ + Rate: 10, + Per: 10, + } + + svc.ApplyRateLimits(session, policy, &apiLimits) + + assert.Equal(t, 15, int(apiLimits.Rate)) + assert.Equal(t, 10, int(session.Rate)) + }) +} + +func TestApplyRateLimits_FromCustomPolicies(t *testing.T) { + svc := &policy.Service{} + + t.Run("Custom policies", func(t *testing.T) { + session := &user.SessionState{} + session.SetCustomPolicies([]user.Policy{ + { + ID: "pol1", + Partitions: user.PolicyPartitions{RateLimit: true}, + Rate: 8, + Per: 1, + AccessRights: map[string]user.AccessDefinition{"a": {}}, + }, + { + ID: "pol2", + Partitions: user.PolicyPartitions{RateLimit: true}, + Rate: 10, + Per: 1, + AccessRights: map[string]user.AccessDefinition{"a": {}}, + }, + }) + + svc.Apply(session) + + assert.Equal(t, 10, int(session.Rate)) + }) +} diff --git a/internal/policy/store.go b/internal/policy/store.go new file mode 100644 index 00000000000..909c53e8bca --- /dev/null +++ b/internal/policy/store.go @@ -0,0 +1,35 @@ +package policy + +import ( + "github.com/TykTechnologies/tyk/user" +) + +// Store is an in-memory policy storage object that +// implements the repository for policy access. We +// do not implement concurrency protections here. +type Store struct { + policies map[string]user.Policy +} + +func NewStore(policies map[string]user.Policy) *Store { + return &Store{ + policies: policies, + } +} + +func (s *Store) PolicyIDs() []string { + policyIDs := make([]string, 0, len(s.policies)) + for _, val := range s.policies { + policyIDs = append(policyIDs, val.ID) + } + return policyIDs +} + +func (s *Store) PolicyByID(id string) (user.Policy, bool) { + v, ok := s.policies[id] + return v, ok +} + +func (s *Store) PolicyCount() int { + return len(s.policies) +} diff --git a/internal/policy/util.go b/internal/policy/util.go new file mode 100644 index 00000000000..ed34211c0f4 --- /dev/null +++ b/internal/policy/util.go @@ -0,0 +1,129 @@ +package policy + +import ( + "github.com/TykTechnologies/tyk/user" +) + +// appendIfMissing ensures dest slice is unique with new items. +func appendIfMissing(src []string, in ...string) []string { + // Use map for uniqueness + srcMap := map[string]bool{} + for _, v := range src { + srcMap[v] = true + } + for _, v := range in { + srcMap[v] = true + } + + // Produce unique []string, maintain sort order + uniqueSorted := func(src []string, keys map[string]bool) []string { + result := make([]string, 0, len(keys)) + for _, v := range src { + // append missing value + if val := keys[v]; val { + result = append(result, v) + delete(keys, v) + } + } + return result + } + + // no new items from `in` + if len(srcMap) == len(src) { + return src + } + + src = uniqueSorted(src, srcMap) + in = uniqueSorted(in, srcMap) + + return append(src, in...) +} + +// intersection gets intersection of the given two slices. +func intersection(a []string, b []string) (inter []string) { + m := make(map[string]bool) + + for _, item := range a { + m[item] = true + } + + for _, item := range b { + if _, ok := m[item]; ok { + inter = append(inter, item) + } + } + + return +} + +// contains checks whether the given slice contains the given item. +func contains(s []string, i string) bool { + for _, a := range s { + if a == i { + return true + } + } + return false +} + +// greaterThanFloat64 checks whether first float64 value is bigger than second float64 value. +// -1 means infinite and the biggest value. +func greaterThanFloat64(first, second float64) bool { + if first == -1 { + return true + } + + if second == -1 { + return false + } + + return first > second +} + +// greaterThanInt64 checks whether first int64 value is bigger than second int64 value. +// -1 means infinite and the biggest value. +func greaterThanInt64(first, second int64) bool { + if first == -1 { + return true + } + + if second == -1 { + return false + } + + return first > second +} + +// greaterThanInt checks whether first int value is bigger than second int value. +// -1 means infinite and the biggest value. +func greaterThanInt(first, second int) bool { + if first == -1 { + return true + } + + if second == -1 { + return false + } + + return first > second +} + +func copyAllowedURLs(input []user.AccessSpec) []user.AccessSpec { + if input == nil { + return nil + } + + copied := make([]user.AccessSpec, len(input)) + + for i, as := range input { + copied[i] = user.AccessSpec{ + URL: as.URL, + } + if as.Methods != nil { + copied[i].Methods = make([]string, len(as.Methods)) + copy(copied[i].Methods, as.Methods) + } + } + + return copied +} diff --git a/user/session.go b/user/session.go index 06abcf746b0..c7d342aa3e1 100644 --- a/user/session.go +++ b/user/session.go @@ -55,18 +55,13 @@ type APILimit struct { Smoothing *apidef.RateLimitSmoothing `json:"smoothing" bson:"smoothing"` } -// Less will return true if the receiver has a smaller duration between requests than `in`. -func (g *APILimit) Less(in APILimit) bool { - return g.Duration() < in.Duration() -} - // Duration returns the time between two allowed requests at the defined rate. // It's used to decide which rate limit has a bigger allowance. func (g *APILimit) Duration() time.Duration { if g.Per <= 0 || g.Rate <= 0 { return 0 } - return time.Second * time.Duration(g.Rate/g.Per) + return time.Duration(float64(time.Second) * g.Per / g.Rate) } // AccessDefinition defines which versions of an API a key has access to diff --git a/user/session_test.go b/user/session_test.go index a2b0f6b5a92..276650752d8 100644 --- a/user/session_test.go +++ b/user/session_test.go @@ -140,51 +140,13 @@ func Test_calculateLifetime(t *testing.T) { }) } -func TestAPILimit_Less(t *testing.T) { - t.Run("limit1 less than limit2", func(t *testing.T) { - limit1 := APILimit{ - Rate: 1, - Per: 2, - } - limit2 := APILimit{ - Rate: 2, - Per: 2, - } - assert.True(t, limit1.Less(limit2)) - }) - - t.Run("limit1 equal to limit2", func(t *testing.T) { - limit1 := APILimit{ - Rate: 1, - Per: 1, - } - limit2 := APILimit{ - Rate: 1, - Per: 1, - } - assert.False(t, limit1.Less(limit2)) - }) - - t.Run("limit1 greater than limit2", func(t *testing.T) { - limit1 := APILimit{ - Rate: 3, - Per: 1, - } - limit2 := APILimit{ - Rate: 1, - Per: 1, - } - assert.False(t, limit1.Less(limit2)) - }) -} - func TestAPILimit_Duration(t *testing.T) { t.Run("valid limit", func(t *testing.T) { limit := APILimit{ Rate: 1, Per: 2, } - expectedDuration := time.Second * time.Duration(limit.Rate/limit.Per) + expectedDuration := 2 * time.Second assert.Equal(t, expectedDuration, limit.Duration()) })