Skip to content

Commit

Permalink
partial implementation for rate limit headers
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed May 24, 2014
1 parent 1b61feb commit aa8fc1a
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 9 deletions.
32 changes: 32 additions & 0 deletions flask_limiter/backports/total_ordering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
total_ordering backport from http://code.activestate.com/recipes/576685/
"""

def total_ordering(cls):
'Class decorator that fills-in missing ordering methods'
convert = {
'__lt__': [('__gt__', lambda self, other: other < self),
('__le__', lambda self, other: not other < self),
('__ge__', lambda self, other: not self < other)],
'__le__': [('__ge__', lambda self, other: other <= self),
('__lt__', lambda self, other: not other <= self),
('__gt__', lambda self, other: not self <= other)],
'__gt__': [('__lt__', lambda self, other: other > self),
('__ge__', lambda self, other: not other > self),
('__le__', lambda self, other: not self > other)],
'__ge__': [('__le__', lambda self, other: other >= self),
('__gt__', lambda self, other: not other >= self),
('__lt__', lambda self, other: not self >= other)]
}
if hasattr(object, '__lt__'):
roots = [op for op in convert if getattr(cls, op) is not getattr(object, op)]
else:
roots = set(dir(cls)) & set(convert)
assert roots, 'must define at least one ordering operation: < > <= >='
root = max(roots) # prefer __lt __ to __le__ to __gt__ to __ge__
for opname, opfunc in convert[root]:
if opname not in roots:
opfunc.__name__ = opname
opfunc.__doc__ = getattr(int, opname).__doc__
setattr(cls, opname, opfunc)
return cls
38 changes: 36 additions & 2 deletions flask_limiter/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from functools import wraps
import logging

from flask import request, current_app
from flask import request, current_app, g
import time

from .errors import RateLimitExceeded, ConfigurationError
from .strategies import STRATEGIES
Expand All @@ -21,11 +22,16 @@ class Limiter(object):
Defaults to the remote address of the request.
"""

def __init__(self, app=None, key_func=get_ipaddr, global_limits=[]):
def __init__(self, app=None
, key_func=get_ipaddr
, global_limits=[]
, headers_enabled=False
):
self.app = app
self.enabled = True
self.global_limits = []
self.exempt_routes = []
self.headers_enabled = headers_enabled
for limit in global_limits:
self.global_limits.extend(
[
Expand All @@ -49,6 +55,10 @@ def init_app(self, app):
:param app: :class:`flask.Flask` instance to rate limit.
"""
self.enabled = app.config.setdefault("RATELIMIT_ENABLED", True)
self.headers_enabled = (
self.headers_enabled
or app.config.setdefault("RATELIMIT_HEADERS_ENABLED", False)
)
self.storage = storage_from_string(
app.config.setdefault('RATELIMIT_STORAGE_URL', 'memory://')
)
Expand All @@ -62,6 +72,24 @@ def init_app(self, app):
(self.key_func, limit) for limit in parse_many(conf_limits)
]
app.before_request(self.__check_request_limit)
app.after_request(self.__inject_headers)

def __inject_headers(self, response):
current_limit = getattr(g, '_view_header_rate_limit', None)
if self.enabled and self.headers_enabled and current_limit:
response.headers.add(
'X-RateLimit-Limit',
str(current_limit[0].amount)
)
response.headers.add(
'X-RateLimit-Remaining',
str(self.limiter.get_remaining(*current_limit))
)
response.headers.add(
'X-RateLimit-Reset',
str(self.limiter.get_refresh(*current_limit))
)
return response

def __check_request_limit(self):
endpoint = request.endpoint or ""
Expand Down Expand Up @@ -92,12 +120,18 @@ def __check_request_limit(self):
)

failed_limit = None
limit_for_header = None
for key_func, limit in (limits + dynamic_limits or self.global_limits):
if not limit_for_header or limit < limit_for_header[0]:
limit_for_header = (limit, key_func(), endpoint)
if not self.limiter.hit(limit, key_func(), endpoint):
self.logger.warn(
"ratelimit %s (%s) exceeded at endpoint: %s" % (
limit, key_func(), endpoint))
failed_limit = limit
limit_for_header = (limit, key_func(), endpoint)
g._view_header_rate_limit = limit_for_header

if failed_limit:
raise RateLimitExceeded(failed_limit)

Expand Down
9 changes: 7 additions & 2 deletions flask_limiter/limits.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""
"""

from six import add_metaclass

try:
from functools import total_ordering
except ImportError:
from .backports.total_ordering import total_ordering # pragma: no cover

TIME_TYPES = dict(
DAY=(60 * 60 * 24, "day"),
Expand All @@ -28,6 +30,7 @@ def __new__(cls, name, parents, dct):

#pylint: disable=no-member
@add_metaclass(RateLimitItemMeta)
@total_ordering
class RateLimitItem(object):
"""
defines a Rate limited resource which contains characteristics
Expand Down Expand Up @@ -80,6 +83,8 @@ def __repr__(self):
self.amount, self.multiples, self.granularity[1]
)

def __lt__(self, other):
return self.granularity[0] < other.granularity[0]

#pylint: disable=invalid-name
class PER_YEAR(RateLimitItem):
Expand Down
59 changes: 56 additions & 3 deletions flask_limiter/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ def get(self, key):
"""
raise NotImplementedError

@abstractmethod
def get_expiry(self, key):
"""
:param str key: the key to get the expiry for
"""
raise NotImplementedError




class LockableEntry(threading._RLock):
__slots__ = ["atime", "expiry"]
def __init__(self, expiry):
Expand Down Expand Up @@ -125,15 +135,34 @@ def acquire_entry(self, key, limit, expiry, no_add=False):
except IndexError:
entry = None
if entry and entry.atime >= timestamp - expiry:
with entry:
if entry in self.events[key]:
self.events[key].remove(entry)
return False
else:
if not no_add:
self.events[key].insert(0, LockableEntry(expiry))
return True

def get_expiry(self, key):
"""
:param str key: the key to get the expiry for
"""
return int(self.expirations.get(key, -1))

def get_acquirable(self, key, limit, expiry):
"""
returns the number of acquirable entries
:param str key: rate limit key to acquire an entry in
:param int limit: amount of entries allowed
:param int expiry: expiry of the entry
"""
timestamp = time.time()
if not self.events.get(key):
return limit
else:
return limit - len(
[k for k in self.events[key] if k.atime >= timestamp - expiry]
)

class RedisStorage(Storage):
"""
rate limit storage with redis as backend
Expand Down Expand Up @@ -192,6 +221,22 @@ def acquire_entry(self, key, limit, expiry, no_add=False):
pipeline.execute()
return True

def get_acquirable(self, key, limit, expiry):
"""
returns the number of acquirable entries
:param str key: rate limit key to acquire an entry in
:param int limit: amount of entries allowed
:param int expiry: expiry of the entry
"""
raise NotImplementedError

def get_expiry(self, key):
"""
:param str key: the key to get the expiry for
"""
return int(self.storage.ttl(key) + time.time())

class MemcachedStorage(Storage):
"""
rate limit storage with memcached as backend
Expand Down Expand Up @@ -246,8 +291,16 @@ def incr(self, key, expiry, elastic_expiry=False):
and retry < self.MAX_CAS_RETRIES
):
value, cas = self.storage.gets(key)
self.storage.add(key + "/expires", expiry)
retry += 1
return int(value) + 1
else:
return self.storage.incr(key, 1)
return 1

def get_expiry(self, key):
"""
:param str key: the key to get the expiry for
"""
return int(self.storage.get(key + "/expires"))

84 changes: 84 additions & 0 deletions flask_limiter/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import ABCMeta, abstractmethod
import weakref
import six
import time


@six.add_metaclass(ABCMeta)
Expand Down Expand Up @@ -36,8 +37,33 @@ def check(self, item, *identifiers):
"""
raise NotImplementedError

@abstractmethod
def get_remaining(self, item, *identifiers):
"""
returns the number of requests remaining within this limit.
:param item: a :class:`RateLimitItem` instance
:param identifiers: variable list of strings to uniquely identify the
limit
:return: int
"""
raise NotImplementedError

@abstractmethod
def get_refresh(self, item, *identifiers):
"""
returns the UTC time when the window will be refreshed
:param item: a :class:`RateLimitItem` instance
:param identifiers: variable list of strings to uniquely identify the
limit
:return: int
"""
raise NotImplementedError


class MovingWindowRateLimiter(RateLimiter):

def __init__(self, storage):
if not hasattr(storage, "acquire_entry"):
raise NotImplementedError("MovingWindowRateLimiting is not implemented for storage of type %s" % storage.__class__)
Expand Down Expand Up @@ -65,6 +91,27 @@ def check(self, item, *identifiers):
"""
return self.storage().acquire_entry(item.key_for(*identifiers), item.amount, item.expiry, True)

def get_remaining(self, item, *identifiers):
"""
returns the number of requests remaining within this limit.
:param item: a :class:`RateLimitItem` instance
:param identifiers: variable list of strings to uniquely identify the
limit
:return: int
"""
return self.storage().get_acquirable(item.key_for(*identifiers), item.amount, item.expiry)

def get_refresh(self, item, *identifiers):
"""
returns the UTC time when the window will be refreshed
:param item: a :class:`RateLimitItem` instance
:param identifiers: variable list of strings to uniquely identify the
limit
:return: int
"""
return int(time.time() + 1)

class FixedWindowRateLimiter(RateLimiter):
def hit(self, item, *identifiers):
Expand All @@ -91,6 +138,27 @@ def check(self, item, *identifiers):
"""
return self.storage().get(item.key_for(*identifiers)) <= item.amount

def get_remaining(self, item, *identifiers):
"""
returns the number of requests remaining within this limit.
:param item: a :class:`RateLimitItem` instance
:param identifiers: variable list of strings to uniquely identify the
limit
:return: int
"""
return max(0, item.amount - self.storage().get(item.key_for(*identifiers)))

def get_refresh(self, item, *identifiers):
"""
returns the UTC time when the window will be refreshed
:param item: a :class:`RateLimitItem` instance
:param identifiers: variable list of strings to uniquely identify the
limit
:return: int
"""
return self.storage().get_expiry(item.key_for(*identifiers))

class FixedWindowElasticExpiryRateLimiter(FixedWindowRateLimiter):
def hit(self, item, *identifiers):
Expand All @@ -107,6 +175,22 @@ def hit(self, item, *identifiers):
<= item.amount
)

def get_refresh(self, item, *identifiers):
"""
returns the UTC time when the window will be refreshed
:param item: a :class:`RateLimitItem` instance
:param identifiers: variable list of strings to uniquely identify the
limit
:return: int
"""
penalty = 0
if not self.check():
penalty = item.amount
return super(FixedWindowElasticExpiryRateLimiter, self).get_refresh(
item, *identifiers
) + penalty

STRATEGIES = {
"fixed-window": FixedWindowRateLimiter,
"fixed-window-elastic-expiry": FixedWindowElasticExpiryRateLimiter,
Expand Down
2 changes: 0 additions & 2 deletions tests/test_flask_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,5 +407,3 @@ def t():
str(int(time.time()))
)
timeline.forward(49)
resp = cli.get("/t1")
print(resp.headers)

0 comments on commit aa8fc1a

Please sign in to comment.