Skip to content

Commit

Permalink
Merge cff1099 into 7140d75
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed Oct 23, 2017
2 parents 7140d75 + cff1099 commit bd317fa
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 84 deletions.
4 changes: 2 additions & 2 deletions doc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -607,13 +607,13 @@ The `error_message` argument can either be a simple string or a callable that re
def error_handler():
return app.config.get("DEFAULT_ERROR_MESSAGE")
@limiter.limit("1/second", error_message='chill!')
@app.route("/")
@limiter.limit("1/second", error_message='chill!')
def index():
....
@limiter.limit("10/second", error_message=error_handler)
@app.route("/ping")
@limiter.limit("10/second", error_message=error_handler)
def ping():
....
Expand Down
201 changes: 120 additions & 81 deletions flask_limiter/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(
self._fallback_limiter = None
self.__check_backend_count = 0
self.__last_check_backend = time.time()
self.__marked_for_limiting = {}

class BlackHoleHandler(logging.StreamHandler):
def emit(*_):
Expand Down Expand Up @@ -271,7 +272,7 @@ def check(self):
:raises: RateLimitExceeded
"""
self.__check_request_limit()
self.__check_request_limit(False)

def reset(self):
"""
Expand Down Expand Up @@ -312,7 +313,52 @@ def __inject_headers(self, response):
)
return response

def __check_request_limit(self):
def __evaluate_limits(self, endpoint, limits):
failed_limit = None
limit_for_header = None
for lim in limits:
limit_scope = lim.scope or endpoint
if lim.is_exempt:
return
if lim.methods is not None and request.method.lower(
) not in lim.methods:
return
if lim.per_method:
limit_scope += ":%s" % request.method
limit_key = lim.key_func()

args = [limit_key, limit_scope]
if all(args):
if self._key_prefix:
args = [self._key_prefix] + args
if not limit_for_header or lim.limit < limit_for_header[0]:
limit_for_header = [lim.limit] + args
if not self.limiter.hit(lim.limit, *args):
self.logger.warning(
"ratelimit %s (%s) exceeded at endpoint: %s",
lim.limit, limit_key, limit_scope
)
failed_limit = lim
limit_for_header = [lim.limit] + args
break
else:
self.logger.error(
"Skipping limit: %s. Empty value found in parameters.",
lim.limit
)
continue
g.view_rate_limit = limit_for_header

if failed_limit:
if failed_limit.error_message:
exc_description = failed_limit.error_message if not callable(
failed_limit.error_message
) else failed_limit.error_message()
else:
exc_description = six.text_type(failed_limit.limit)
raise RateLimitExceeded(exc_description)

def __check_request_limit(self, in_middleware=True):
endpoint = request.endpoint or ""
view_func = current_app.view_functions.get(endpoint, None)
name = (
Expand All @@ -325,21 +371,45 @@ def __check_request_limit(self):
or name in self._exempt_routes
or request.blueprint in self._blueprint_exempt
or any(fn() for fn in self._request_filters)
or g.get("_rate_limiting_complete")
):
return
limits = (
name in self._route_limits and self._route_limits[name] or []
limits, dynamic_limits = [], []

# this is to ensure backward compatibility with behavior that
# existed accidentally, i.e::
#
# @limiter.limit(...)
# @app.route('...')
# def func(...):
#
# The above setup would work in pre 1.0 versions because the decorator
# was not acting immediately and instead simply registering the rate
# limiting. The correct way to use the decorator is to wrap
# the limiter with the route, i.e::
#
# @app.route(...)
# @limiter.limit(...)
# def func(...):

implicit_decorator = view_func in self.__marked_for_limiting.get(
name, []
)
dynamic_limits = []
if name in self._dynamic_route_limits:
for lim in self._dynamic_route_limits[name]:
try:
dynamic_limits.extend(list(lim))
except ValueError as e:
self.logger.error(
"failed to load ratelimit for view function %s (%s)",
name, e
)

if not in_middleware or implicit_decorator:
limits = (
name in self._route_limits and self._route_limits[name] or []
)
dynamic_limits = []
if name in self._dynamic_route_limits:
for lim in self._dynamic_route_limits[name]:
try:
dynamic_limits.extend(list(lim))
except ValueError as e:
self.logger.error(
"failed to load ratelimit for view function %s (%s)",
name, e
)
if request.blueprint:
if (request.blueprint in self._blueprint_dynamic_limits
and not dynamic_limits
Expand All @@ -365,66 +435,31 @@ def __check_request_limit(self):
if request.blueprint in self._blueprint_limits and not limits:
limits.extend(self._blueprint_limits[request.blueprint])

failed_limit = None
limit_for_header = None
try:
all_limits = []
if self._storage_dead and self._fallback_limiter:
if self.__should_check_backend() and self._storage.check():
self.logger.info("Rate limit storage recovered")
self._storage_dead = False
self.__check_backend_count = 0
if in_middleware and name in self.__marked_for_limiting:
pass
else:
all_limits = list(
itertools.chain(*self._in_memory_fallback)
)
if not all_limits:
all_limits = itertools.chain(
itertools.chain(*self._application_limits),
(limits + dynamic_limits)
or itertools.chain(*self._default_limits)
)
for lim in all_limits:
limit_scope = lim.scope or endpoint
if lim.is_exempt:
return
if lim.methods is not None and request.method.lower(
) not in lim.methods:
return
if lim.per_method:
limit_scope += ":%s" % request.method
limit_key = lim.key_func()

args = [limit_key, limit_scope]
if all(args):
if self._key_prefix:
args = [self._key_prefix] + args
if not limit_for_header or lim.limit < limit_for_header[0]:
limit_for_header = [lim.limit] + args
if not self.limiter.hit(lim.limit, *args):
self.logger.warning(
"ratelimit %s (%s) exceeded at endpoint: %s",
lim.limit, limit_key, limit_scope
if self.__should_check_backend() and self._storage.check():
self.logger.info("Rate limit storage recovered")
self._storage_dead = False
self.__check_backend_count = 0
else:
all_limits = list(
itertools.chain(*self._in_memory_fallback)
)
failed_limit = lim
limit_for_header = [lim.limit] + args
break
else:
self.logger.error(
"Skipping limit: %s. Empty value found in parameters.",
lim.limit
)
continue
g.view_rate_limit = limit_for_header

if failed_limit:
if failed_limit.error_message:
exc_description = failed_limit.error_message if not callable(
failed_limit.error_message
) else failed_limit.error_message()
else:
exc_description = six.text_type(failed_limit.limit)
raise RateLimitExceeded(exc_description)
if not all_limits:
route_limits = limits + dynamic_limits
all_limits = list(itertools.chain(*self._application_limits))
all_limits += route_limits
if (
not route_limits
and not (in_middleware and name in self.__marked_for_limiting)
or implicit_decorator
):
all_limits += list(itertools.chain(*self._default_limits))
self.__evaluate_limits(endpoint, all_limits)
except Exception as e: # no qa
if isinstance(e, RateLimitExceeded):
six.reraise(*sys.exc_info())
Expand All @@ -434,7 +469,7 @@ def __check_request_limit(self):
" in-memory storage"
)
self._storage_dead = True
self.__check_request_limit()
self.__check_request_limit(in_middleware)
else:
if self._swallow_errors:
self.logger.exception(
Expand All @@ -452,7 +487,7 @@ def __limit_decorator(
per_method=False,
methods=None,
error_message=None,
exempt_when=None
exempt_when=None,
):
_scope = scope if shared else None

Expand Down Expand Up @@ -491,11 +526,7 @@ def _inner(obj):
name, []
).extend(static_limits)
else:

@wraps(obj)
def __inner(*a, **k):
return obj(*a, **k)

self.__marked_for_limiting.setdefault(name, []).append(obj)
if dynamic_limit:
self._dynamic_route_limits.setdefault(
name, []
Expand All @@ -504,8 +535,14 @@ def __inner(*a, **k):
self._route_limits.setdefault(
name, []
).extend(static_limits)
return __inner

@wraps(obj)
def __inner(*a, **k):
if self._auto_check and not g.get("_rate_limiting_complete"):
self.__check_request_limit(False)
g._rate_limiting_complete = True
return obj(*a, **k)
return __inner
return _inner

def limit(
Expand All @@ -515,7 +552,7 @@ def limit(
per_method=False,
methods=None,
error_message=None,
exempt_when=None
exempt_when=None,
):
"""
decorator to be used for rate limiting individual routes or blueprints.
Expand All @@ -530,6 +567,7 @@ def limit(
limited (default: None).
:param error_message: string (or callable that returns one) to override the
error message used in the response.
:param exempt_when:
:return:
"""
return self.__limit_decorator(
Expand All @@ -538,7 +576,7 @@ def limit(
per_method=per_method,
methods=methods,
error_message=error_message,
exempt_when=exempt_when
exempt_when=exempt_when,
)

def shared_limit(
Expand All @@ -547,7 +585,7 @@ def shared_limit(
scope,
key_func=None,
error_message=None,
exempt_when=None
exempt_when=None,
):
"""
decorator to be applied to multiple routes sharing the same rate limit.
Expand All @@ -560,14 +598,15 @@ def shared_limit(
the rate limit. defaults to remote address of the request.
:param error_message: string (or callable that returns one) to override the
error message used in the response.
:param exempt_when:
"""
return self.__limit_decorator(
limit_value,
key_func,
True,
scope,
error_message=error_message,
exempt_when=exempt_when
exempt_when=exempt_when,
)

def exempt(self, obj):
Expand Down
Loading

0 comments on commit bd317fa

Please sign in to comment.