Skip to content

Commit

Permalink
rate limit headers
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed May 25, 2014
1 parent e859623 commit 4773a22
Show file tree
Hide file tree
Showing 8 changed files with 404 additions and 9 deletions.
34 changes: 34 additions & 0 deletions doc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ The following flask configuration values are honored by
while memcached relies on the `pymemcache`_ package.
``RATELIMIT_STRATEGY`` The rate limiting strategy to use. :ref:`ratelimit-strategy`
for details.
``RATELIMIT_HEADERS_ENABLED`` Enables returning :ref:`ratelimit-headers`. Defaults to ``False``
``RATELIMIT_ENABLED`` Overall killswitch for ratelimits. Defaults to ``True``
============================== ================================================

Expand Down Expand Up @@ -242,6 +243,39 @@ rate limit as the window for each limit is not fixed at the start and end of eac
however a higher memory cost associated with this strategy as it requires ``N`` items to
be maintained in memory per resource and rate limit.

.. _ratelimit-headers:

Rate-limiting Headers
=====================

If the configuration is enabled, information about the rate limit with respect to the
route being requested will be written as part of the response. Since multiple rate limits
can be active for a given route - the rate limit with the lowest time granularity will be
used in the scenario when the request does not breach any rate limits.

.. tabularcolumns:: |p{6.5cm}|p{8.5cm}|

============================== ================================================
``X-RateLimit-Limit`` The total number of requests allowed for the
active window
``X-RateLimit-Remaining`` The number of requests remaining in the active
window.
``X-RateLimit-Reset`` UTC seconds since epoch when the window will be
reset.
============================== ================================================

Depending on the :ref:`ratelimit-strategy` chosen, the meaning of the headers
may differ. For example, with a moving window strategy there is no actual
reset for the window, and therefore the value of ``X-RateLimit-Reset`` is always
``now() + 1``.

.. warning:: Enabling the headers has an additional with certain storage / strategy combinations.

* Memcached + Fixed Window: an extra key per rate limit is stored to calculate
``X-RateLimit-Reset``
* Redis + Moving Window: an extra call to redis is involved during every request
to calculate ``X-RateLimit-Remaining``

.. _keyfunc-customization:

Customization
Expand Down
32 changes: 32 additions & 0 deletions flask_limiter/backports/total_ordering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
total_ordering backport from http://code.activestate.com/recipes/576685/
"""

def total_ordering(cls):
'Class decorator that fills-in missing ordering methods'
convert = {
'__lt__': [('__gt__', lambda self, other: other < self),
('__le__', lambda self, other: not other < self),
('__ge__', lambda self, other: not self < other)],
'__le__': [('__ge__', lambda self, other: other <= self),
('__lt__', lambda self, other: not other <= self),
('__gt__', lambda self, other: not self <= other)],
'__gt__': [('__lt__', lambda self, other: other > self),
('__ge__', lambda self, other: not other > self),
('__le__', lambda self, other: not self > other)],
'__ge__': [('__le__', lambda self, other: other >= self),
('__gt__', lambda self, other: not other >= self),
('__lt__', lambda self, other: not self >= other)]
}
if hasattr(object, '__lt__'):
roots = [op for op in convert if getattr(cls, op) is not getattr(object, op)]
else:
roots = set(dir(cls)) & set(convert)
assert roots, 'must define at least one ordering operation: < > <= >='
root = max(roots) # prefer __lt __ to __le__ to __gt__ to __ge__
for opname, opfunc in convert[root]:
if opname not in roots:
opfunc.__name__ = opname
opfunc.__doc__ = getattr(int, opname).__doc__
setattr(cls, opname, opfunc)
return cls
39 changes: 37 additions & 2 deletions flask_limiter/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from functools import wraps
import logging

from flask import request, current_app
from flask import request, current_app, g
import time

from .errors import RateLimitExceeded, ConfigurationError
from .strategies import STRATEGIES
Expand All @@ -19,13 +20,19 @@ class Limiter(object):
limits to apply to all routes. :ref:`ratelimit-string` for more details.
:param function key_func: a callable that returns the domain to rate limit by.
Defaults to the remote address of the request.
:param bool headers_enabled: whether ``X-RateLimit`` response headers are written.
"""

def __init__(self, app=None, key_func=get_ipaddr, global_limits=[]):
def __init__(self, app=None
, key_func=get_ipaddr
, global_limits=[]
, headers_enabled=False
):
self.app = app
self.enabled = True
self.global_limits = []
self.exempt_routes = []
self.headers_enabled = headers_enabled
for limit in global_limits:
self.global_limits.extend(
[
Expand All @@ -49,6 +56,10 @@ def init_app(self, app):
:param app: :class:`flask.Flask` instance to rate limit.
"""
self.enabled = app.config.setdefault("RATELIMIT_ENABLED", True)
self.headers_enabled = (
self.headers_enabled
or app.config.setdefault("RATELIMIT_HEADERS_ENABLED", False)
)
self.storage = storage_from_string(
app.config.setdefault('RATELIMIT_STORAGE_URL', 'memory://')
)
Expand All @@ -62,6 +73,24 @@ def init_app(self, app):
(self.key_func, limit) for limit in parse_many(conf_limits)
]
app.before_request(self.__check_request_limit)
app.after_request(self.__inject_headers)

def __inject_headers(self, response):
current_limit = getattr(g, '_view_header_rate_limit', None)
if self.enabled and self.headers_enabled and current_limit:
response.headers.add(
'X-RateLimit-Limit',
str(current_limit[0].amount)
)
response.headers.add(
'X-RateLimit-Remaining',
str(self.limiter.get_remaining(*current_limit))
)
response.headers.add(
'X-RateLimit-Reset',
str(self.limiter.get_refresh(*current_limit))
)
return response

def __check_request_limit(self):
endpoint = request.endpoint or ""
Expand Down Expand Up @@ -92,12 +121,18 @@ def __check_request_limit(self):
)

failed_limit = None
limit_for_header = None
for key_func, limit in (limits + dynamic_limits or self.global_limits):
if not limit_for_header or limit < limit_for_header[0]:
limit_for_header = (limit, key_func(), endpoint)
if not self.limiter.hit(limit, key_func(), endpoint):
self.logger.warn(
"ratelimit %s (%s) exceeded at endpoint: %s" % (
limit, key_func(), endpoint))
failed_limit = limit
limit_for_header = (limit, key_func(), endpoint)
g._view_header_rate_limit = limit_for_header

if failed_limit:
raise RateLimitExceeded(failed_limit)

Expand Down
9 changes: 7 additions & 2 deletions flask_limiter/limits.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""
"""

from six import add_metaclass

try:
from functools import total_ordering
except ImportError:
from .backports.total_ordering import total_ordering # pragma: no cover

TIME_TYPES = dict(
DAY=(60 * 60 * 24, "day"),
Expand All @@ -28,6 +30,7 @@ def __new__(cls, name, parents, dct):

#pylint: disable=no-member
@add_metaclass(RateLimitItemMeta)
@total_ordering
class RateLimitItem(object):
"""
defines a Rate limited resource which contains characteristics
Expand Down Expand Up @@ -80,6 +83,8 @@ def __repr__(self):
self.amount, self.multiples, self.granularity[1]
)

def __lt__(self, other):
return self.granularity[0] < other.granularity[0]

#pylint: disable=invalid-name
class PER_YEAR(RateLimitItem):
Expand Down
77 changes: 74 additions & 3 deletions flask_limiter/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ def get(self, key):
"""
raise NotImplementedError

@abstractmethod
def get_expiry(self, key):
"""
:param str key: the key to get the expiry for
"""
raise NotImplementedError




class LockableEntry(threading._RLock):
__slots__ = ["atime", "expiry"]
def __init__(self, expiry):
Expand Down Expand Up @@ -125,15 +135,34 @@ def acquire_entry(self, key, limit, expiry, no_add=False):
except IndexError:
entry = None
if entry and entry.atime >= timestamp - expiry:
with entry:
if entry in self.events[key]:
self.events[key].remove(entry)
return False
else:
if not no_add:
self.events[key].insert(0, LockableEntry(expiry))
return True

def get_expiry(self, key):
"""
:param str key: the key to get the expiry for
"""
return int(self.expirations.get(key, -1))

def get_acquirable(self, key, limit, expiry):
"""
returns the number of acquirable entries
:param str key: rate limit key to acquire an entry in
:param int limit: amount of entries allowed
:param int expiry: expiry of the entry
"""
timestamp = time.time()
if not self.events.get(key):
return limit
else:
return limit - len(
[k for k in self.events[key] if k.atime >= timestamp - expiry]
)

class RedisStorage(Storage):
"""
rate limit storage with redis as backend
Expand All @@ -149,6 +178,20 @@ def __init__(self, redis_url):
self.storage = get_dependency("redis").from_url(redis_url)
if not self.storage.ping():
raise ConfigurationError("unable to connect to redis at %s" % redis_url) # pragma: no cover
script = """
local items = redis.call('lrange', KEYS[1], 0, tonumber(ARGV[2]))
local expiry = tonumber(ARGV[1])
local a = 0
for idx=1,#items do
if tonumber(items[idx]) >= expiry then
a = a + 1
else
break
end
end
return a
"""
self.script_hash = self.storage.script_load(script)
super(RedisStorage, self).__init__()

def incr(self, key, expiry, elastic_expiry=False):
Expand Down Expand Up @@ -192,6 +235,25 @@ def acquire_entry(self, key, limit, expiry, no_add=False):
pipeline.execute()
return True

def get_acquirable(self, key, limit, expiry):
"""
returns the number of acquirable entries
:param str key: rate limit key to acquire an entry in
:param int limit: amount of entries allowed
:param int expiry: expiry of the entry
"""
timestamp = time.time()
return limit - self.storage.evalsha(
self.script_hash, 1, key, int(timestamp - expiry), limit
)

def get_expiry(self, key):
"""
:param str key: the key to get the expiry for
"""
return int(self.storage.ttl(key) + time.time())

class MemcachedStorage(Storage):
"""
rate limit storage with memcached as backend
Expand Down Expand Up @@ -247,7 +309,16 @@ def incr(self, key, expiry, elastic_expiry=False):
):
value, cas = self.storage.gets(key)
retry += 1
self.storage.set(key + "/expires", expiry + time.time(), expire=expiry, noreply=False)
return int(value) + 1
else:
return self.storage.incr(key, 1)
self.storage.set(key + "/expires", expiry + time.time(), expire=expiry, noreply=False)
return 1

def get_expiry(self, key):
"""
:param str key: the key to get the expiry for
"""
return int(float(self.storage.get(key + "/expires") or time.time()))

Loading

0 comments on commit 4773a22

Please sign in to comment.