Skip to content

Commit

Permalink
Simplify limit management for decorated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed Dec 28, 2022
1 parent 2677705 commit d52d662
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 72 deletions.
60 changes: 12 additions & 48 deletions flask_limiter/extension.py
Expand Up @@ -240,7 +240,6 @@ def __init__(
self.limit_manager = LimitManager(
application_limits=_application_limits,
default_limits=_default_limits,
static_decorated_limits={},
dynamic_decorated_limits={},
static_blueprint_limits={},
dynamic_blueprint_limits={},
Expand Down Expand Up @@ -1075,10 +1074,7 @@ def __init__(
self.cost = cost
self.is_static = not callable(self.limit_value)
self.shared = shared

@property
def dynamic_limit(self) -> Optional[LimitGroup]:
return LimitGroup(
self.limit_group: Optional[LimitGroup] = LimitGroup(
limit_provider=self.limit_value,
key_function=self.key_func,
scope=self.scope,
Expand All @@ -1093,40 +1089,16 @@ def dynamic_limit(self) -> Optional[LimitGroup]:
shared=self.shared,
)

@property
def static_limits(self) -> List[Limit]:
return list(
LimitGroup(
limit_provider=self.limit_value,
key_function=self.key_func,
scope=self.scope,
per_method=self.per_method,
methods=self.methods,
error_message=self.error_message,
exempt_when=self.exempt_when,
override_defaults=self.override_defaults,
deduct_when=self.deduct_when,
on_breach=self.on_breach,
cost=self.cost,
shared=self.shared,
)
)

def __enter__(self) -> None:
tb = traceback.extract_stack(limit=2)
qualified_location = f"{tb[0].filename}:{tb[0].name}:{tb[0].lineno}"

# TODO: if use as a context manager becomes interesting/valuable
# a less hacky approach than using the traceback and piggy backing
# on the limit manager's knowledge of decorated limits might be worth it.
if not self.is_static:
self.limiter.limit_manager.add_decorated_runtime_limit(
qualified_location, self.dynamic_limit, override=True
)
else:
self.limiter.limit_manager.add_decorated_static_limit(
qualified_location, *self.static_limits, override=True
)
self.limiter.limit_manager.add_decorated_limit(
qualified_location, self.limit_group, override=True
)

self.limiter.limit_manager.add_endpoint_hint(
flask.request.endpoint, qualified_location
Expand Down Expand Up @@ -1160,12 +1132,11 @@ def __call__(
name = obj.name
else:
name = get_qualified_name(obj)
dynamic_limit = self.dynamic_limit if not self.is_static else None
static_limits = []
if self.is_static:
if self.is_static and self.limit_group:
try:
static_limits = self.static_limits
list(self.limit_group)
except ValueError as e:
self.limit_group = None
self.limiter.logger.error(
"failed to configure %s %s (%s)",
"view function" if is_route else "blueprint",
Expand All @@ -1175,24 +1146,17 @@ def __call__(

if isinstance(obj, flask.Blueprint):
if not self.is_static:
self.limiter.limit_manager.add_runtime_blueprint_limits(
name, dynamic_limit
self.limiter.limit_manager.add_runtime_blueprint_limit(
name, self.limit_group
)
else:
elif self.limit_group:
self.limiter.limit_manager.add_static_blueprint_limits(
name, *static_limits
name, *list(self.limit_group)
)
return None
else:
self.limiter._marked_for_limiting.add(name)
if not self.is_static:
self.limiter.limit_manager.add_decorated_runtime_limit(
name, dynamic_limit
)
else:
self.limiter.limit_manager.add_decorated_static_limit(
name, *static_limits
)
self.limiter.limit_manager.add_decorated_limit(name, self.limit_group)

@wraps(obj)
def __inner(*a: P.args, **k: P.kwargs) -> R:
Expand Down
39 changes: 15 additions & 24 deletions flask_limiter/manager.py
Expand Up @@ -17,7 +17,6 @@ def __init__(
self,
application_limits: List[LimitGroup],
default_limits: List[LimitGroup],
static_decorated_limits: Dict[str, OrderedSet[Limit]],
dynamic_decorated_limits: Dict[str, OrderedSet[LimitGroup]],
static_blueprint_limits: Dict[str, OrderedSet[Limit]],
dynamic_blueprint_limits: Dict[str, OrderedSet[LimitGroup]],
Expand All @@ -26,8 +25,7 @@ def __init__(
) -> None:
self._application_limits = application_limits
self._default_limits = default_limits
self._static_decorated_limits = static_decorated_limits
self._runtime_decorated_limits = dynamic_decorated_limits
self._decorated_limits = dynamic_decorated_limits
self._static_blueprint_limits = static_blueprint_limits
self._runtime_blueprint_limits = dynamic_blueprint_limits
self._route_exemptions = route_exemptions
Expand All @@ -49,26 +47,22 @@ def set_application_limits(self, limits: List[LimitGroup]) -> None:
def set_default_limits(self, limits: List[LimitGroup]) -> None:
self._default_limits = limits

def add_decorated_runtime_limit(
self, route: str, limit: LimitGroup, override: bool = False
def add_decorated_limit(
self, route: str, limit: Optional[LimitGroup], override: bool = False
) -> None:
if not override:
self._runtime_decorated_limits.setdefault(route, OrderedSet()).add(limit)
else:
self._runtime_decorated_limits[route] = OrderedSet([limit])

def add_runtime_blueprint_limits(self, blueprint: str, limit: LimitGroup) -> None:
self._runtime_blueprint_limits.setdefault(blueprint, OrderedSet()).add(limit)
if limit:
if not override:
self._decorated_limits.setdefault(route, OrderedSet()).add(limit)
else:
self._decorated_limits[route] = OrderedSet([limit])

def add_decorated_static_limit(
self, route: str, *limits: Limit, override: bool = False
def add_runtime_blueprint_limit(
self, blueprint: str, limit: Optional[LimitGroup]
) -> None:
if not override:
self._static_decorated_limits.setdefault(route, OrderedSet()).update(
OrderedSet(limits)
if limit:
self._runtime_blueprint_limits.setdefault(blueprint, OrderedSet()).add(
limit
)
else:
self._static_decorated_limits[route] = OrderedSet(limits)

def add_static_blueprint_limits(self, blueprint: str, *limits: Limit) -> None:
self._static_blueprint_limits.setdefault(blueprint, OrderedSet()).update(
Expand Down Expand Up @@ -190,11 +184,8 @@ def exemption_scope(
def decorated_limits(self, callable_name: str) -> List[Limit]:
limits = []
if not self._route_exemptions[callable_name]:
for limit in self._static_decorated_limits.get(callable_name, []):
limits.append(limit)

if callable_name in self._runtime_decorated_limits:
for group in self._runtime_decorated_limits[callable_name]:
if callable_name in self._decorated_limits:
for group in self._decorated_limits[callable_name]:
try:
for limit in group:
limits.append(limit)
Expand Down

0 comments on commit d52d662

Please sign in to comment.