Skip to content

Commit

Permalink
Fix concurrent session writing release2.9 {do not merge, checking tes…
Browse files Browse the repository at this point in the history
…ts} (#3300)

<!-- Provide a general summary of your changes in the Title above -->

## Description
backport #3274  to 2.9

## Related Issue
None, just backporting #3274

## Motivation and Context


## How This Has Been Tested

## Screenshots (if appropriate)

## Types of changes
<!-- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Refactoring or add test (improvements in base code or adds test coverage to functionality)

## Checklist
<!-- Go over all the following points, and put an `x` in all the boxes that apply -->
<!-- If you're unsure about any of these, don't hesitate to ask; we're here to help! -->
- [ ] Make sure you are requesting to **pull a topic/feature/bugfix branch** (right side). If pulling from your own
      fork, don't request your `master`!
- [ ] Make sure you are making a pull request against the **`master` branch** (left side). Also, you should start
      *your branch* off *our latest `master`*.
- [ ] My change requires a change to the documentation.
  - [ ] If you've changed APIs, describe what needs to be updated in the documentation.
  - [ ] If new config option added, ensure that it can be set via ENV variable
- [ ] I have updated the documentation accordingly.
- [ ] Modules and vendor dependencies have been updated; run `go mod tidy && go mod vendor`
- [ ] When updating library version must provide reason/explanation for this update.
- [ ] I have added tests to cover my changes.
- [ ] All new and existing tests passed.
- [ ] Check your code additions will not fail linting checks:
  - [ ] `go fmt -s`
  - [ ] `go vet`
  • Loading branch information
sredxny committed Sep 3, 2020
1 parent e5b78e9 commit b5187a8
Show file tree
Hide file tree
Showing 42 changed files with 338 additions and 148 deletions.
3 changes: 3 additions & 0 deletions cli/linter/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,9 @@
"drl_threshold": {
"type": "number"
},
"drl_enable_sentinel_rate_limiter": {
"type":"boolean"
},
"enable_analytics": {
"type": "boolean"
},
Expand Down
14 changes: 9 additions & 5 deletions coprocess/grpc/coprocess_grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http"
"os"
"strings"
"sync"
"testing"

"context"
Expand Down Expand Up @@ -73,7 +74,7 @@ func (d *dispatcher) Dispatch(ctx context.Context, object *coprocess.Object) (*c
return d.grpcError(object, "Request content type should be either JSON or multipart")
}
case "testPostHook1":
testKeyValue, ok := object.Session.Metadata["testkey"]
testKeyValue, ok := object.Session.GetMetadata()["testkey"]
if !ok {
return d.grpcError(object, "'testkey' not found in session metadata")
}
Expand All @@ -88,17 +89,17 @@ func (d *dispatcher) Dispatch(ctx context.Context, object *coprocess.Object) (*c
if nestedKeyValue != "nestedvalue" {
return d.grpcError(object, "'nestedvalue' value doesn't match")
}
testKey2Value, ok := object.Session.Metadata["testkey2"]
testKey2Value, ok := object.Session.GetMetadata()["testkey2"]
if !ok {
return d.grpcError(object, "'testkey' not found in session metadata")
}
if testKey2Value != "testvalue" {
return d.grpcError(object, "'testkey2' value doesn't match")
}

// Check for compatibility (object.Metadata should contain the same keys as object.Session.Metadata)
// Check for compatibility (object.Metadata should contain the same keys as object.Session.GetMetadata())
for k, v := range object.Metadata {
sessionKeyValue, ok := object.Session.Metadata[k]
sessionKeyValue, ok := object.Session.GetMetadata()[k]
if !ok {
return d.grpcError(object, k+" not found in object.Session.Metadata")
}
Expand Down Expand Up @@ -242,6 +243,7 @@ func TestGRPCDispatch(t *testing.T) {
"testkey": map[string]interface{}{"nestedkey": "nestedvalue"},
"testkey2": "testvalue",
}
s.Mutex = &sync.RWMutex{}
})
headers := map[string]string{"authorization": keyID}

Expand Down Expand Up @@ -341,7 +343,9 @@ func BenchmarkGRPCDispatch(b *testing.B) {
defer ts.Close()
defer grpcServer.Stop()

keyID := gateway.CreateSession(func(s *user.SessionState) {})
keyID := gateway.CreateSession(func(s *user.SessionState) {
s.Mutex = &sync.RWMutex{}
})
headers := map[string]string{"authorization": keyID}

b.Run("Pre Hook with SetHeaders", func(b *testing.B) {
Expand Down
2 changes: 2 additions & 0 deletions coprocess/python/coprocess_python_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -255,6 +256,7 @@ func TestPythonBundles(t *testing.T) {
"testkey": map[string]interface{}{"nestedkey": "nestedvalue"},
"stringkey": "testvalue",
}
s.Mutex = &sync.RWMutex{}
})

gateway.BuildAndLoadAPI(func(spec *gateway.APISpec) {
Expand Down
45 changes: 31 additions & 14 deletions gateway/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func getApisIdsForOrg(orgID string) []string {
func checkAndApplyTrialPeriod(keyName, apiId string, newSession *user.SessionState, isHashed bool) {

// Check the policies to see if we are forcing an expiry on the key
for _, polID := range newSession.PolicyIDs() {
for _, polID := range newSession.GetPolicyIDs() {
policiesMu.RLock()
policy, ok := policiesByID[polID]
policiesMu.RUnlock()
Expand Down Expand Up @@ -228,11 +228,11 @@ func doAddOrUpdate(keyName string, newSession *user.SessionState, dontReset bool
newSession.LastUpdated = strconv.Itoa(int(time.Now().Unix()))
}

if len(newSession.AccessRights) > 0 {
if len(newSession.GetAccessRights()) > 0 {
// reset API-level limit to nil if any has a zero-value
resetAPILimits(newSession.AccessRights)
// We have a specific list of access rules, only add / update those
for apiId := range newSession.AccessRights {
for apiId := range newSession.GetAccessRights() {
apiSpec := getApiSpec(apiId)
if apiSpec == nil {
log.WithFields(logrus.Fields{
Expand Down Expand Up @@ -330,7 +330,9 @@ func handleAddOrUpdate(keyName string, r *http.Request, isHashed bool) (interfac
suppressReset := r.URL.Query().Get("suppress_reset") == "1"

// decode payload
newSession := user.SessionState{}
newSession := user.SessionState{
Mutex: &sync.RWMutex{},
}

contents, _ := ioutil.ReadAll(r.Body)
r.Body = ioutil.NopCloser(bytes.NewReader(contents))
Expand All @@ -346,7 +348,9 @@ func handleAddOrUpdate(keyName string, r *http.Request, isHashed bool) (interfac
// DO ADD OR UPDATE

// get original session in case of update and preserve fields that SHOULD NOT be updated
originalKey := user.SessionState{}
originalKey := user.SessionState{
Mutex: &sync.RWMutex{},
}
if r.Method == http.MethodPut {
found := false
for apiID := range newSession.AccessRights {
Expand All @@ -372,11 +376,11 @@ func handleAddOrUpdate(keyName string, r *http.Request, isHashed bool) (interfac
newSession.LastUpdated = originalKey.LastUpdated

// on ACL API limit level
for apiID, access := range originalKey.AccessRights {
for apiID, access := range originalKey.GetAccessRights() {
if access.Limit == nil {
continue
}
if newAccess, ok := newSession.AccessRights[apiID]; ok && newAccess.Limit != nil {
if newAccess, ok := newSession.GetAccessRightByAPIID(apiID); ok && newAccess.Limit != nil {
newAccess.Limit.QuotaRenews = access.Limit.QuotaRenews
newSession.AccessRights[apiID] = newAccess
}
Expand Down Expand Up @@ -505,7 +509,7 @@ func handleGetDetail(sessionKey, apiID string, byHash bool) (interface{}, int) {
}

// populate remaining quota for API limits (if any)
for id, access := range session.AccessRights {
for id, access := range session.GetAccessRights() {
if access.Limit == nil || access.Limit.QuotaMax == -1 || access.Limit.QuotaMax == 0 {
continue
}
Expand Down Expand Up @@ -596,7 +600,10 @@ func handleAddKey(keyName, hashedName, sessionString, apiID string) {
SessionManager: FallbackKeySesionManager,
},
}
sess := user.SessionState{}
sess := user.SessionState{
Mutex: &sync.RWMutex{},
}

json.Unmarshal([]byte(sessionString), &sess)
sess.LastUpdated = strconv.Itoa(int(time.Now().Unix()))
var err error
Expand Down Expand Up @@ -624,13 +631,20 @@ func handleDeleteKey(keyName, apiID string, resetQuota bool) (interface{}, int)
if apiID == "-1" {
// Go through ALL managed API's and delete the key
apisMu.RLock()

removed := false
for _, spec := range apisByID {
if spec.SessionManager.RemoveSession(keyName, false) {
removed = true
}
spec.SessionManager.ResetQuota(keyName, &user.SessionState{}, false)
spec.SessionManager.ResetQuota(
keyName,
&user.SessionState{
Mutex: &sync.RWMutex{},
},
false)
}

apisMu.RUnlock()

if !removed {
Expand Down Expand Up @@ -668,7 +682,7 @@ func handleDeleteKey(keyName, apiID string, resetQuota bool) (interface{}, int)
}

if resetQuota {
sessionManager.ResetQuota(keyName, &user.SessionState{}, false)
sessionManager.ResetQuota(keyName, &user.SessionState{Mutex: &sync.RWMutex{}}, false)
}

statusObj := apiModifyKeySuccess{
Expand Down Expand Up @@ -742,7 +756,7 @@ func handleDeleteHashedKey(keyName, apiID string, resetQuota bool) (interface{},
}

if resetQuota {
sessionManager.ResetQuota(keyName, &user.SessionState{}, true)
sessionManager.ResetQuota(keyName, &user.SessionState{Mutex: &sync.RWMutex{}}, true)
}

statusObj := apiModifyKeySuccess{
Expand Down Expand Up @@ -1075,6 +1089,7 @@ func orgHandler(w http.ResponseWriter, r *http.Request) {

func handleOrgAddOrUpdate(keyName string, r *http.Request) (interface{}, int) {
newSession := new(user.SessionState)
newSession.Mutex = &sync.RWMutex{}

if err := json.NewDecoder(r.Body).Decode(newSession); err != nil {
log.Error("Couldn't decode new session object: ", err)
Expand Down Expand Up @@ -1258,6 +1273,7 @@ func resetHandler(fn func()) http.HandlerFunc {

func createKeyHandler(w http.ResponseWriter, r *http.Request) {
newSession := new(user.SessionState)
newSession.Mutex = &sync.RWMutex{}
if err := json.NewDecoder(r.Body).Decode(newSession); err != nil {
log.WithFields(logrus.Fields{
"prefix": "api",
Expand Down Expand Up @@ -1288,10 +1304,10 @@ func createKeyHandler(w http.ResponseWriter, r *http.Request) {
mw := BaseMiddleware{}
mw.ApplyPolicies(newSession)

if len(newSession.AccessRights) > 0 {
if len(newSession.GetAccessRights()) > 0 {
// reset API-level limit to nil if any has a zero-value
resetAPILimits(newSession.AccessRights)
for apiID := range newSession.AccessRights {
for apiID := range newSession.GetAccessRights() {
apiSpec := getApiSpec(apiID)
if apiSpec != nil {
checkAndApplyTrialPeriod(newKey, apiID, newSession, false)
Expand Down Expand Up @@ -1400,6 +1416,7 @@ func createKeyHandler(w http.ResponseWriter, r *http.Request) {

func previewKeyHandler(w http.ResponseWriter, r *http.Request) {
newSession := new(user.SessionState)
newSession.Mutex = &sync.RWMutex{}
if err := json.NewDecoder(r.Body).Decode(newSession); err != nil {
log.WithFields(logrus.Fields{
"prefix": "api",
Expand Down
1 change: 1 addition & 0 deletions gateway/api_definition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ func testPrepareDefaultVersion() string {
s.AccessRights = map[string]user.AccessDefinition{"test": {
APIID: "test", Versions: []string{"v1", "v2"},
}}
s.Mutex = &sync.RWMutex{}
})
}

Expand Down
25 changes: 18 additions & 7 deletions gateway/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,13 +242,15 @@ func TestKeyHandler(t *testing.T) {
s.AccessRights = map[string]user.AccessDefinition{"test": {
APIID: "test", Versions: []string{"v1"},
}}
s.Mutex = &sync.RWMutex{}
})

_, unknownOrgKey := ts.CreateSession(func(s *user.SessionState) {
s.OrgID = "dummy"
s.AccessRights = map[string]user.AccessDefinition{"test": {
APIID: "test", Versions: []string{"v1"},
}}
s.Mutex = &sync.RWMutex{}
})

t.Run("Get key", func(t *testing.T) {
Expand All @@ -274,6 +276,7 @@ func TestKeyHandler(t *testing.T) {
s.AccessRights = map[string]user.AccessDefinition{"test": {
APIID: "test", Versions: []string{"v1"},
}}
s.Mutex = &sync.RWMutex{}
})

assert := func(response *http.Response, expected []string) {
Expand Down Expand Up @@ -359,6 +362,7 @@ func TestKeyHandler_UpdateKey(t *testing.T) {
s.AccessRights = map[string]user.AccessDefinition{testAPIID: {
APIID: testAPIID, Versions: []string{"v1"},
}}
s.Mutex = &sync.RWMutex{}
})

t.Run("Add policy not enforcing acl", func(t *testing.T) {
Expand All @@ -371,7 +375,8 @@ func TestKeyHandler_UpdateKey(t *testing.T) {
}...)

sessionState, found := FallbackKeySesionManager.SessionDetail(key, false)
if !found || sessionState.AccessRights[testAPIID].APIID != testAPIID || len(sessionState.ApplyPolicies) != 2 {
accessRight, _ := sessionState.GetAccessRightByAPIID(testAPIID)
if !found || accessRight.APIID != testAPIID || len(sessionState.ApplyPolicies) != 2 {
t.Fatal("Adding policy to the list failed")
}
})
Expand All @@ -386,7 +391,8 @@ func TestKeyHandler_UpdateKey(t *testing.T) {
}...)

sessionState, found := FallbackKeySesionManager.SessionDetail(key, false)
if !found || sessionState.AccessRights[testAPIID].APIID != testAPIID || len(sessionState.ApplyPolicies) != 0 {
accessRight, _ := sessionState.GetAccessRightByAPIID(testAPIID)
if !found || accessRight.APIID != testAPIID || len(sessionState.ApplyPolicies) != 0 {
t.Fatal("Removing policy from the list failed")
}
})
Expand Down Expand Up @@ -443,8 +449,8 @@ func TestKeyHandler_UpdateKey(t *testing.T) {

sessionState, found := FallbackKeySesionManager.SessionDetail(key, false)

if !found || !reflect.DeepEqual(expected, sessionState.MetaData) {
t.Fatalf("Expected %v, returned %v", expected, sessionState.MetaData)
if !found || !reflect.DeepEqual(expected, sessionState.GetMetaData()) {
t.Fatalf("Expected %v, returned %v", expected, sessionState.GetMetaData())
}
}

Expand Down Expand Up @@ -477,9 +483,9 @@ func TestKeyHandler_UpdateKey(t *testing.T) {
"key-meta2": "key-value2",
}
session.ApplyPolicies = []string{pID, pID2}
session.MetaData = map[string]interface{}{
session.SetMetaData(map[string]interface{}{
"key-meta2": "key-value2",
}
})
assertMetaData(session, expected)
})
})
Expand Down Expand Up @@ -1360,7 +1366,12 @@ func TestContextSession(t *testing.T) {
if ctxGetSession(r) != nil {
t.Fatal("expected ctxGetSession to return nil")
}
ctxSetSession(r, &user.SessionState{}, "", false)
ctxSetSession(r,
&user.SessionState{
Mutex: &sync.RWMutex{},
},
"",
false)
if ctxGetSession(r) == nil {
t.Fatal("expected ctxGetSession to return non-nil")
}
Expand Down
2 changes: 1 addition & 1 deletion gateway/auth_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (b *DefaultSessionManager) ResetQuota(keyName string, session *user.Session
go b.store.DeleteRawKey(rawKey)
//go b.store.SetKey(rawKey, "0", session.QuotaRenewalRate)

for _, acl := range session.AccessRights {
for _, acl := range session.GetAccessRights() {
rawKey = QuotaKeyPrefix + acl.AllowanceScope + "-" + keyName
go b.store.DeleteRawKey(rawKey)
}
Expand Down
14 changes: 8 additions & 6 deletions gateway/auth_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gateway

import (
"net/http"
"sync"
"testing"

"github.com/TykTechnologies/tyk/storage"
Expand All @@ -26,9 +27,10 @@ func TestAuthenticationAfterDeleteKey(t *testing.T) {
})[0]

key := CreateSession(func(s *user.SessionState) {
s.AccessRights = map[string]user.AccessDefinition{api.APIID: {
s.SetAccessRights(map[string]user.AccessDefinition{api.APIID: {
APIID: api.APIID,
}}
}})
s.Mutex = &sync.RWMutex{}
})
deletePath := "/tyk/keys/" + key
authHeader := map[string]string{
Expand Down Expand Up @@ -68,9 +70,9 @@ func TestAuthenticationAfterUpdateKey(t *testing.T) {
key := generateToken("", "")

session := CreateStandardSession()
session.AccessRights = map[string]user.AccessDefinition{api.APIID: {
session.SetAccessRights(map[string]user.AccessDefinition{api.APIID: {
APIID: api.APIID,
}}
}})

FallbackKeySesionManager.UpdateSession(storage.HashKey(key), session, 0, config.Global().HashKeys)

Expand All @@ -82,9 +84,9 @@ func TestAuthenticationAfterUpdateKey(t *testing.T) {
{Path: "/get", Headers: authHeader, Code: http.StatusOK},
}...)

session.AccessRights = map[string]user.AccessDefinition{"dummy": {
session.SetAccessRights(map[string]user.AccessDefinition{"dummy": {
APIID: "dummy",
}}
}})

FallbackKeySesionManager.UpdateSession(storage.HashKey(key), session, 0, config.Global().HashKeys)

Expand Down
Loading

0 comments on commit b5187a8

Please sign in to comment.