Skip to content

Commit

Permalink
implementation for shared limits (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed Jun 24, 2014
1 parent f93f383 commit 353ab2e
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 25 deletions.
68 changes: 68 additions & 0 deletions doc/source/index.rst
Expand Up @@ -142,6 +142,74 @@ instance are
.. note:: The callable is called from within a
:ref:`flask request context <flask:request-context>`.

.. _ratelimit-decorator-shared-limit:

:meth:`Limiter.shared_limit`
For scenarios where a rate limit should be shared by multiple routes
(For example when you want to protect routes using the same resource
with an umbrella rate limit).

Named shared limit

.. code-block:: python
mysql_limit = limiter.limit("100/hour", scope="mysql")
@app.route("..")
@mysql_limit
def r1():
...
@app.route("..")
@mysql_limit
def r2():
...
Anonymous shared limit: when no scope is provided a unique id will be used.

.. code-block:: python
shared_limit = limiter.limit("100/hour")
@app.route("..")
@shared_limit
def r1():
...
@app.route("..")
@shared_limit
def r2():
...
Dynamic shared limit: when a callable is passed as scope, the return value
of the function will be used as the scope.

.. code-block:: python
def host_scope():
return request.host
host_limit = limiter.shared_limit("100/hour", scope=host_scope)
@app.route("..")
@host_limit
def r1():
...
@app.route("..")
@host_limit
def r2():
...
.. note:: Shared rate limits provide the same conveniences as individual rate limits,

* They can be chained with other shared limits or other individual limits
* They accept keying functions
* accept callables to determine the rate limit value



.. _ratelimit-decorator-exempt:

:meth:`Limiter.exempt`
Expand Down
56 changes: 40 additions & 16 deletions flask_limiter/extension.py
Expand Up @@ -4,6 +4,7 @@

from functools import wraps
import logging
import uuid

from flask import request, current_app, g

Expand Down Expand Up @@ -140,7 +141,9 @@ def __check_request_limit(self):
failed_limit = None
limit_for_header = None
for key_func, limit, scope in (limits + dynamic_limits or self.global_limits):
limit_scope = scope or endpoint
limit_scope = (
scope if not callable(scope) else scope(endpoint)
) or endpoint
if not limit_for_header or limit < limit_for_header[0]:
limit_for_header = (limit, key_func(), limit_scope)
if not self.limiter.hit(limit, key_func(), limit_scope):
Expand All @@ -155,42 +158,63 @@ def __check_request_limit(self):
if failed_limit:
raise RateLimitExceeded(failed_limit)

def limit(self, limit_value, key_func=None, scope=None):
"""
decorator to be used for rate limiting specific routes.
:param limit_value: rate limit string or a callable that returns a string.
:ref:`ratelimit-string` for more details.
:param key_func: function/lambda to extract the unique identifier for
the rate limit. defaults to remote address of the request.
:param scope: a string for defining the rate limiting scope.
defaults to the route's endpoint.
:return:
"""
def __limit_decorator(self, limit_value,
key_func=None, shared=False,
scope=None):
_scope = scope or uuid.uuid1().hex if shared else None

def _inner(fn):
name = "%s.%s" % (fn.__module__, fn.__name__)

@wraps(fn)
def __inner(*a, **k):
return fn(*a, **k)
func = key_func or self.key_func
if callable(limit_value):
self.dynamic_route_limits.setdefault(name, []).append(
(func, limit_value, scope)
(func, limit_value, _scope)
)
else:
try:
self.route_limits.setdefault(name, []).extend(
[(func, limit, scope) for limit in parse_many(limit_value)]
[(func, limit, _scope) for limit in parse_many(limit_value)]
)
except ValueError as e:
self.logger.error(
"failed to configure view function %s (%s)", name, e
)

return __inner
return _inner


def limit(self, limit_value, key_func=None):
"""
decorator to be used for rate limiting individual routes.
:param limit_value: rate limit string or a callable that returns a string.
:ref:`ratelimit-string` for more details.
:param key_func: function/lambda to extract the unique identifier for
the rate limit. defaults to remote address of the request.
:return:
"""
return self.__limit_decorator(limit_value, key_func)


def shared_limit(self, limit_value, key_func=None, scope=None):
"""
decorator to be applied to multiple routes sharing the same rate limit.
:param limit_value: rate limit string or a callable that returns a string.
:ref:`ratelimit-string` for more details.
:param key_func: function/lambda to extract the unique identifier for
the rate limit. defaults to remote address of the request.
:param scope: a string or callable that returns a string
for defining the rate limiting scope. Defaults to a :func:`uuid.uuid1` value.
"""
return self.__limit_decorator(limit_value, key_func, True, scope)



def exempt(self, fn):
"""
decorator to mark a view as exempt from rate limits.
Expand Down
79 changes: 70 additions & 9 deletions tests/test_flask_ext.py
Expand Up @@ -423,25 +423,47 @@ def t():
str(int(time.time() + 49))
)

def test_scope(self):
app = Flask(__name__)
limiter = Limiter(app)
mock_handler = mock.Mock()
mock_handler.level = logging.INFO
limiter.logger.addHandler(mock_handler)
def test_unnamed_shared_limit(self):
app, limiter = self.build_app()
shared_limit_a = limiter.shared_limit("1/minute")
shared_limit_b = limiter.shared_limit("1/minute")
@app.route("/t1")
@shared_limit_a
def route1():
return "route1"

@app.route("/t2")
@shared_limit_a
def route2():
return "route2"

@app.route("/t3")
@shared_limit_b
def route3():
return "route3"

with hiro.Timeline().freeze() as timeline:
with app.test_client() as cli:
self.assertEqual(200, cli.get("/t1").status_code)
self.assertEqual(200, cli.get("/t3").status_code)
self.assertEqual(429, cli.get("/t2").status_code)

def test_named_shared_limit(self):
app, limiter = self.build_app()
shared_limit_a = limiter.shared_limit("1/minute", scope='a')
shared_limit_b = limiter.shared_limit("1/minute", scope='b')
@app.route("/t1")
@limiter.limit("1/minute", scope="a")
@shared_limit_a
def route1():
return "route1"

@app.route("/t2")
@limiter.limit("1/minute", scope="a")
@shared_limit_a
def route2():
return "route2"

@app.route("/t3")
@limiter.limit("1/minute", scope="b")
@shared_limit_b
def route3():
return "route3"

Expand All @@ -451,6 +473,45 @@ def route3():
self.assertEqual(200, cli.get("/t3").status_code)
self.assertEqual(429, cli.get("/t2").status_code)

def test_dynamic_shared_limit(self):
app, limiter = self.build_app()
fn_a = mock.Mock()
fn_b = mock.Mock()
fn_a.return_value = "foo"
fn_b.return_value = "bar"


dy_limit_a = limiter.shared_limit("1/minute", scope=fn_a)
dy_limit_b = limiter.shared_limit("1/minute", scope=fn_b)


@app.route("/t1")
@dy_limit_a
def t1():
return "route1"

@app.route("/t2")
@dy_limit_a
def t2():
return "route2"

@app.route("/t3")
@dy_limit_b
def t3():
return "route3"

with hiro.Timeline().freeze():
with app.test_client() as cli:
self.assertEqual(200, cli.get("/t1").status_code)
self.assertEqual(200, cli.get("/t3").status_code)
self.assertEqual(429, cli.get("/t2").status_code)
self.assertEqual(429, cli.get("/t3").status_code)
self.assertEqual(2, fn_a.call_count)
self.assertEqual(2, fn_b.call_count)
fn_b.assert_called_with("t3")
fn_a.assert_has_calls([mock.call("t1"), mock.call("t2")])


def test_whitelisting(self):

app = Flask(__name__)
Expand Down

0 comments on commit 353ab2e

Please sign in to comment.