Skip to content

Commit

Permalink
Merge 8723fb1 into 7630942
Browse files Browse the repository at this point in the history
  • Loading branch information
bohea committed May 25, 2018
2 parents 7630942 + 8723fb1 commit 1d5b142
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 5 deletions.
95 changes: 91 additions & 4 deletions limits/storage.py
Expand Up @@ -293,10 +293,78 @@ class RedisInteractor(object):
return current
"""

SCRIPT_ACQUIRE_TOKEN = """
local key = KEYS[1]
local intervalPerToken = tonumber(ARGV[2])
local currentTime = tonumber(ARGV[1])
local maxToken = tonumber(ARGV[3])
local initToken = tonumber(ARGV[4])
local maxInterval = tonumber(ARGV[5])
local tokens
local bucket = redis.call("hmget", key, "lastTime", "lastToken")
local lastTime = bucket[1]
local lastToken = bucket[2]
if lastTime == false or lastToken == false then
tokens = initToken
redis.call('hset', key, 'lastTime', currentTime)
else
local thisInterval = currentTime - tonumber(lastTime)
if thisInterval > maxInterval then
tokens = initToken
redis.call('hset', key, 'lastTime', currentTime)
elseif thisInterval > 0 then
local tokensToAdd = math.floor(thisInterval / intervalPerToken)
tokens = math.min(lastToken + tokensToAdd, maxToken)
redis.call('hset', key, 'lastTime', lastTime + intervalPerToken * tokensToAdd)
else
tokens = lastToken
end
end
if tokens == 0 then
redis.call('hset', key, 'lastToken', tokens)
return false
else
redis.call('hset', key, 'lastToken', tokens - 1)
return true
end
"""

SCRIPT_TOKEN_BUCKET = """
local key = KEYS[1]
local intervalPerToken = tonumber(ARGV[2])
local currentTime = tonumber(ARGV[1])
local maxToken = tonumber(ARGV[3])
local initToken = tonumber(ARGV[4])
local maxInterval = tonumber(ARGV[5])
local tokens
local bucket = redis.call("hmget", key, "lastTime", "lastToken")
local lastTime = bucket[1]
local lastToken = bucket[2]
local newTime
if lastTime == false or lastToken == false then
tokens = initToken
newTime = currentTime
else
local thisInterval = currentTime - tonumber(lastTime)
if thisInterval > maxInterval then
tokens = initToken
newTime = currentTime
elseif thisInterval > 0 then
local tokensToAdd = math.floor(thisInterval / intervalPerToken)
tokens = math.min(lastToken + tokensToAdd, maxToken)
newTime = lastTime + intervalPerToken * tokensToAdd
else
tokens = lastToken
newTime = lastTime
end
end
return {newTime, tokens}
"""

def incr(self, key, expiry, connection, elastic_expiry=False):
"""
increments the counter for a given rate limit key
:param connection: Redis connection
:param str key: the key to increment
:param int expiry: amount in seconds for the key to expire in
Expand All @@ -316,7 +384,7 @@ def get(self, key, connection):
def get_moving_window(self, key, limit, expiry):
"""
returns the starting point and the number of entries in the moving window
:param str key: rate limit key
:param int expiry: expiry of entry
"""
Expand Down Expand Up @@ -359,6 +427,12 @@ def check(self, connection):
except: # noqa
return False

def get_token_bucket(self, key, current_time, interval_per_token, max_tokens, init_tokens, max_interval):
return self.lua_token_bucket(
[key],
[current_time, interval_per_token, max_tokens, init_tokens, max_interval]
)


class RedisStorage(RedisInteractor, Storage):
"""
Expand Down Expand Up @@ -393,11 +467,17 @@ def initialize_storage(self, uri):
self.lua_incr_expire = self.storage.register_script(
RedisStorage.SCRIPT_INCR_EXPIRE
)
self.lua_acquire_token = self.storage.register_script(
RedisStorage.SCRIPT_ACQUIRE_TOKEN
)
self.lua_token_bucket = self.storage.register_script(
RedisStorage.SCRIPT_TOKEN_BUCKET
)

def incr(self, key, expiry, elastic_expiry=False):
"""
increments the counter for a given rate limit key
:param str key: the key to increment
:param int expiry: amount in seconds for the key to expire in
"""
Expand Down Expand Up @@ -442,13 +522,20 @@ def reset(self):
"""WARNING, this operation was designed to be fast, but was not tested
on a large production based system. Be careful with its usage as it
could be slow on very large data sets.
This function calls a Lua Script to delete keys prefixed with 'LIMITER'
in block of 5000."""

cleared = self.lua_clear_keys(['LIMITER*'])
return cleared

def acquire_token(self, key, current_time, interval_per_token,
max_tokens, init_tokens, max_interval):
print(key, current_time, interval_per_token, max_tokens, init_tokens, max_interval)
return self.lua_acquire_token(
[key],
[current_time, interval_per_token, max_tokens, init_tokens, max_interval]
)


class RedisSSLStorage(RedisStorage):
"""
Expand Down
77 changes: 76 additions & 1 deletion limits/strategies.py
Expand Up @@ -178,10 +178,85 @@ def hit(self, item, *identifiers):
item.key_for(*identifiers), item.get_expiry(), True
) <= item.amount
)


class TokenBucketRateLimiter(RateLimiter):
"""
Reference: :ref:`token-bucket`
"""

def __init__(self, storage):
if not (
hasattr(storage, "acquire_token")
or hasattr(storage, "get_window_stats")
):
raise NotImplementedError(
"TokenBucketRateLimiter is not implemented for storage of type %s"
% storage.__class__
)
super(TokenBucketRateLimiter, self).__init__(storage)

def hit(self, item, *identifiers):
"""
creates a hit on the rate limit and returns True if successful.
:param item: a :class:`RateLimitItem` instance
:param identifiers: variable list of strings to uniquely identify the
limit
:return: True/False
"""
max_tokens = item.amount
max_interval = item.get_expiry() * 1000
init_tokens = int(max_tokens / 3) + 1
interval_per_token = int(max_interval / max_tokens)
return bool(
self.storage().acquire_token(
item.key_for(*identifiers),
int(time.time() * 1000),
interval_per_token,
max_tokens,
init_tokens,
max_interval
)
)

def test(self, item, *identifiers):
"""
checks the rate limit and returns True if it is not
currently exceeded.
:param item: a :class:`RateLimitItem` instance
:param identifiers: variable list of strings to uniquely identify the
limit
:return: True/False
"""
bucket = self.get_window_stats(item, *identifiers)
tokens = bucket[1]
return tokens > 0

def get_window_stats(self, item, *identifiers):
"""
returns the number of requests remaining and reset of this limit.
:param item: a :class:`RateLimitItem` instance
:param identifiers: variable list of strings to uniquely identify the
limit
:return: tuple (last refill time (int), remaining token (int))
"""
max_tokens = item.amount
max_interval = item.get_expiry() * 1000
init_tokens = int(max_tokens / 3) + 1
interval_per_token = int(max_interval / max_tokens)
return self.storage().get_token_bucket(
item.key_for(*identifiers),
int(time.time() * 1000),
interval_per_token,
max_tokens,
init_tokens,
max_interval
)


STRATEGIES = {
"fixed-window": FixedWindowRateLimiter,
"fixed-window-elastic-expiry": FixedWindowElasticExpiryRateLimiter,
"moving-window": MovingWindowRateLimiter
"moving-window": MovingWindowRateLimiter,
"token-bucket": TokenBucketRateLimiter,
}

0 comments on commit 1d5b142

Please sign in to comment.