Skip to content

Commit

Permalink
Added redis DECR-based quote check, still writes back to sesison cont…
Browse files Browse the repository at this point in the history
…ext but actual block is habdled by transaction. Fixes #40
  • Loading branch information
Martin Buhr committed Jan 23, 2015
1 parent ec5223e commit a15c551
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 7 deletions.
5 changes: 5 additions & 0 deletions auth_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type SessionHandler interface {
RemoveSession(keyName string)
GetSessionDetail(keyName string) (SessionState, bool)
GetSessions(filter string) []string
GetStore() StorageHandler
}

type KeyGenerator interface {
Expand Down Expand Up @@ -83,6 +84,10 @@ func (b *DefaultSessionManager) Init(store StorageHandler) {
b.Store.Connect()
}

func (b *DefaultSessionManager) GetStore() StorageHandler{
return b.Store
}

// UpdateSession updates the session state in the storage engine
func (b DefaultSessionManager) UpdateSession(keyName string, session SessionState, resetTTLTo int64) {
v, _ := json.Marshal(session)
Expand Down
5 changes: 3 additions & 2 deletions middleware_organisation_activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ func (k *OrganizationMonitor) ProcessRequest(w http.ResponseWriter, r *http.Requ

return errors.New("This organisation access has been disabled, please contact your API administrator."), 403
}

forwardMessage, reason := sessionLimiter.ForwardMessage(&thisSessionState)

storeRef := k.Spec.OrgSessionManager.GetStore()
forwardMessage, reason := sessionLimiter.ForwardMessage(&thisSessionState, thisOrg, storeRef)

// Ensure quota and rate data for this session are recorded

Expand Down
9 changes: 7 additions & 2 deletions middleware_rate_limiting.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@ func (k *RateLimitAndQuotaCheck) ProcessRequest(w http.ResponseWriter, r *http.R
sessionLimiter := SessionLimiter{}
thisSessionState := context.Get(r, SessionData).(SessionState)
authHeaderValue := context.Get(r, AuthHeaderValue).(string)
forwardMessage, reason := sessionLimiter.ForwardMessage(&thisSessionState)


storeRef := k.Spec.SessionManager.GetStore()
forwardMessage, reason := sessionLimiter.ForwardMessage(&thisSessionState, authHeaderValue, storeRef)

// Ensure quota and rate data for this session are recorded
k.Spec.SessionManager.UpdateSession(authHeaderValue, thisSessionState, 0)

// Write it back to the context so we can pass it back in the header
context.Set(r, SessionData, thisSessionState)

log.Debug("SessionState: ", thisSessionState)

Expand Down
42 changes: 40 additions & 2 deletions session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"time"
"strconv"
)

// AccessDefinition defines which versions of an API a key has access to
Expand Down Expand Up @@ -52,7 +53,7 @@ type SessionLimiter struct{}

// ForwardMessage will enforce rate limiting, returning false if session limits have been exceeded.
// Key values to manage rate are Rate and Per, e.g. Rate of 10 messages Per 10 seconds
func (l SessionLimiter) ForwardMessage(currentSession *SessionState) (bool, int) {
func (l SessionLimiter) ForwardMessage(currentSession *SessionState, key string, store StorageHandler) (bool, int) {

current := time.Now().Unix()

Expand All @@ -70,7 +71,7 @@ func (l SessionLimiter) ForwardMessage(currentSession *SessionState) (bool, int)
}

currentSession.Allowance--
if !l.IsQuotaExceeded(currentSession) {
if !l.IsRedisQuotaExceeded(currentSession, key, store) {
return true, 0
}

Expand Down Expand Up @@ -107,6 +108,43 @@ func (l SessionLimiter) IsQuotaExceeded(currentSession *SessionState) bool {

}

func (l SessionLimiter) IsRedisQuotaExceeded(currentSession *SessionState, key string, store StorageHandler) bool {

// Are they unlimited?
if currentSession.QuotaMax == -1 {
// No quota set
return false
}

// Create the key
rawKey := "quota-" + key
quotaVal, rErr := store.GetKey(rawKey)

if rErr != nil {
// Key not found, must have expired, set it to Max and expire when time to renew
store.SetKey(rawKey, strconv.Itoa(int(currentSession.QuotaMax)), currentSession.QuotaRenewalRate)

// Quota has renewed, let them through and set quotmax for records
currentSession.QuotaRemaining = currentSession.QuotaMax
return false
}

remaining, _ := strconv.Atoi(quotaVal)

if remaining > 0 {
// Decrement the quota
store.Decrement(rawKey)
currentSession.QuotaRemaining = int64(remaining)
// Let them through
return false
}

// They have hit zero, none shall pass
currentSession.QuotaRemaining = 0
return true

}

// createSampleSession is a debug function to create a mock session value
func createSampleSession() SessionState {
var thisSession SessionState
Expand Down
29 changes: 28 additions & 1 deletion storage_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type StorageHandler interface {
GetKeysAndValues() map[string]string
GetKeysAndValuesWithFilter(string) map[string]string
DeleteKeys([]string) bool
Decrement(string)
}

// InMemoryStorageManager implements the StorageHandler interface,
Expand All @@ -35,7 +36,11 @@ type InMemoryStorageManager struct {
Sessions map[string]string
}

// Connect will establish a connection to the storage engine
// Decrement is a dummy function
func (s *InMemoryStorageManager) Decrement(n string) {
log.Warning("Not implemented!")
}

func (s *InMemoryStorageManager) Connect() bool {
return true
}
Expand Down Expand Up @@ -223,6 +228,28 @@ func (r *RedisStorageManager) SetKey(keyName string, sessionState string, timeou
}
}

// Decrement will decrement a key in redis with a transaction
func (r *RedisStorageManager) Decrement(keyName string) {
db := r.pool.Get()
defer db.Close()

keyName = r.fixKey(keyName)
log.Debug("Decrementing key: ", keyName)
if db == nil {
log.Info("Connection dropped, connecting..")
r.Connect()
r.Decrement(keyName)
} else {
db.Send("MULTI")
db.Send("DECR", keyName)
_, err := db.Do("EXEC")

if err != nil {
log.Error("Error trying to decrement value:", err)
}
}
}

// GetKeys will return all keys according to the filter (filter is a prefix - e.g. tyk.keys.*)
func (r *RedisStorageManager) GetKeys(filter string) []string {
db := r.pool.Get()
Expand Down

0 comments on commit a15c551

Please sign in to comment.