Skip to content

Commit

Permalink
Global filters for router (#644)
Browse files Browse the repository at this point in the history
* Bump version

* Added more comments

* Cover registering global filters

* Reformat code

* Add more tests

* Rework event propagation to routers mechanism. Fixed compatibility with Python 3.10 syntax (match keyword)

* Fixed tests

* Fixed coverage

Co-authored-by: evgfilim1 <evgfilim1@yandex.ru>
  • Loading branch information
JrooTJunior and evgfilim1 committed Jul 31, 2021
1 parent a70ecb7 commit 4f2cc75
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 31 deletions.
13 changes: 2 additions & 11 deletions aiogram/dispatcher/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,20 +232,11 @@ async def _listen_update(self, update: Update, **kwargs: Any) -> Any:
"installed not latest version of aiogram framework",
RuntimeWarning,
)
raise SkipHandler
raise SkipHandler()

kwargs.update(event_update=update)

for router in self.chain:
kwargs.update(event_router=router)
observer = router.observers[update_type]
response = await observer.trigger(event, update=update, **kwargs)
if response is not UNHANDLED:
break
else:
response = UNHANDLED

return response
return await self.propagate_event(update_type=update_type, event=event, **kwargs)

@classmethod
async def _silent_call_request(cls, bot: Bot, result: TelegramMethod[Any]) -> None:
Expand Down
1 change: 1 addition & 0 deletions aiogram/dispatcher/event/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
]

UNHANDLED = sentinel.UNHANDLED
REJECTED = sentinel.REJECTED


class SkipHandler(Exception):
Expand Down
26 changes: 25 additions & 1 deletion aiogram/dispatcher/event/telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ...types import TelegramObject
from ..filters.base import BaseFilter
from .bases import UNHANDLED, MiddlewareType, NextMiddlewareType, SkipHandler
from .bases import REJECTED, UNHANDLED, MiddlewareType, NextMiddlewareType, SkipHandler
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -32,6 +32,24 @@ def __init__(self, router: Router, event_name: str) -> None:
self.outer_middlewares: List[MiddlewareType] = []
self.middlewares: List[MiddlewareType] = []

# Re-used filters check method from already implemented handler object
# with dummy callback which never will be used
self._handler = HandlerObject(callback=lambda: True, filters=[])

def filter(self, *filters: FilterType, **bound_filters: Any) -> None:
"""
Register filter for all handlers of this event observer
:param filters: positional filters
:param bound_filters: keyword filters
"""
resolved_filters = self.resolve_filters(bound_filters)
if self._handler.filters is None:
self._handler.filters = []
self._handler.filters.extend(
[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)]
)

def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
"""
Register filter class in factory
Expand Down Expand Up @@ -139,6 +157,12 @@ async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
return await wrapped_outer(event, kwargs)

async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
# Check globally defined filters before any other handler will be checked
result, data = await self._handler.check(event, **kwargs)
if not result:
return REJECTED
kwargs.update(data)

for handler in self.handlers:
result, data = await handler.check(event, **kwargs)
if result:
Expand Down
4 changes: 2 additions & 2 deletions aiogram/dispatcher/filters/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def validate_command(self, command: CommandObject) -> CommandObject:
if isinstance(allowed_command, Pattern): # Regexp
result = allowed_command.match(command.command)
if result:
return replace(command, match=result)
return replace(command, regexp_match=result)
elif command.command == allowed_command: # String
return command
raise CommandException("Command did not match pattern")
Expand Down Expand Up @@ -134,7 +134,7 @@ class CommandObject:
"""Mention (if available)"""
args: Optional[str] = field(repr=False, default=None)
"""Command argument"""
match: Optional[Match[str]] = field(repr=False, default=None)
regexp_match: Optional[Match[str]] = field(repr=False, default=None)
"""Will be presented match result if the command is presented as regexp in filter"""

@property
Expand Down
6 changes: 3 additions & 3 deletions aiogram/dispatcher/filters/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ class ExceptionMessageFilter(BaseFilter):
Allow to match exception by message
"""

match: Union[str, Pattern[str]]
pattern: Union[str, Pattern[str]]
"""Regexp pattern"""

class Config:
arbitrary_types_allowed = True

@validator("match")
@validator("pattern")
def _validate_match(cls, value: Union[str, Pattern[str]]) -> Union[str, Pattern[str]]:
if isinstance(value, str):
return re.compile(value)
return value

async def __call__(self, exception: Exception) -> Union[bool, Dict[str, Any]]:
pattern = cast(Pattern[str], self.match)
pattern = cast(Pattern[str], self.pattern)
result = pattern.match(str(exception))
if not result:
return False
Expand Down
6 changes: 3 additions & 3 deletions aiogram/dispatcher/fsm/storage/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def from_url(
return cls(redis=redis, **kwargs)

async def close(self) -> None:
await self.redis.close()
await self.redis.close() # type: ignore

def generate_key(self, bot: Bot, *parts: Any) -> str:
prefix_parts = [self.prefix]
Expand Down Expand Up @@ -73,7 +73,7 @@ async def set_state(
await self.redis.delete(key)
else:
await self.redis.set(
key, state.state if isinstance(state, State) else state, ex=self.state_ttl
key, state.state if isinstance(state, State) else state, ex=self.state_ttl # type: ignore[arg-type]
)

async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
Expand All @@ -89,7 +89,7 @@ async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, A
await self.redis.delete(key)
return
json_data = bot.session.json_dumps(data)
await self.redis.set(key, json_data, ex=self.data_ttl)
await self.redis.set(key, json_data, ex=self.data_ttl) # type: ignore[arg-type]

async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
key = self.generate_key(bot, chat_id, user_id, STATE_DATA_KEY)
Expand Down
18 changes: 18 additions & 0 deletions aiogram/dispatcher/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import warnings
from typing import Any, Dict, Generator, List, Optional, Union

from ..types import TelegramObject
from ..utils.imports import import_module
from ..utils.warnings import CodeHasNoEffect
from .event.bases import REJECTED, UNHANDLED
from .event.event import EventObserver
from .event.telegram import TelegramEventObserver
from .filters import BUILTIN_FILTERS
Expand Down Expand Up @@ -82,6 +84,22 @@ def __init__(self, use_builtin_filters: bool = True) -> None:
for builtin_filter in BUILTIN_FILTERS.get(name, ()):
observer.bind_filter(builtin_filter)

async def propagate_event(self, update_type: str, event: TelegramObject, **kwargs: Any) -> Any:
kwargs.update(event_router=self)
observer = self.observers[update_type]
response = await observer.trigger(event, **kwargs)
if response is REJECTED:
return UNHANDLED
if response is not UNHANDLED:
return response

for router in self.sub_routers:
response = await router.propagate_event(update_type=update_type, event=event, **kwargs)
if response is not UNHANDLED:
break

return response

@property
def chain_head(self) -> Generator[Router, None, None]:
router: Optional[Router] = self
Expand Down
59 changes: 55 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ pydantic = "^1.8.1"
Babel = "^2.9.1"
aiofiles = "^0.6.0"
async_lru = "^1.0.2"
frozenlist = "^1.1.1"
aiohttp-socks = { version = "^0.5.5", optional = true }
aioredis = { version = "^2.0.0a1", allow-prereleases = true, optional = true }
aioredis = { version = "^2.0.0", allow-prereleases = true, optional = true }
magic-filter = { version = "1.0.0a1", allow-prereleases = true }
sphinx = { version = "^3.1.0", optional = true }
sphinx-intl = { version = "^2.0.1", optional = true }
Expand Down
2 changes: 1 addition & 1 deletion tests/test_api/test_methods/test_delete_my_commands.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from aiogram.methods import BanChatMember, DeleteMyCommands, Request
from aiogram.methods import DeleteMyCommands, Request
from tests.mocked_bot import MockedBot


Expand Down
47 changes: 46 additions & 1 deletion tests/test_dispatcher/test_event/test_telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from aiogram.dispatcher.event.bases import SkipHandler
from aiogram.dispatcher.event.bases import REJECTED, SkipHandler
from aiogram.dispatcher.event.handler import HandlerObject
from aiogram.dispatcher.event.telegram import TelegramEventObserver
from aiogram.dispatcher.filters.base import BaseFilter
Expand Down Expand Up @@ -233,3 +233,48 @@ async def my_middleware3(handler, event, data):
assert my_middleware3 in middlewares

assert middlewares == [my_middleware1, my_middleware2, my_middleware3]

def test_register_global_filters(self):
router = Router(use_builtin_filters=False)
assert isinstance(router.message._handler.filters, list)
assert not router.message._handler.filters

my_filter = MyFilter1(test="pass")
router.message.filter(my_filter)

assert len(router.message._handler.filters) == 1
assert router.message._handler.filters[0].callback is my_filter

router.message._handler.filters = None
router.message.filter(my_filter)
assert len(router.message._handler.filters) == 1
assert router.message._handler.filters[0].callback is my_filter

@pytest.mark.asyncio
async def test_global_filter(self):
r1 = Router()
r2 = Router()

async def handler(evt):
return evt

r1.message.filter(lambda evt: False)
r1.message.register(handler)
r2.message.register(handler)

assert await r1.message.trigger(None) is REJECTED
assert await r2.message.trigger(None) is None

@pytest.mark.asyncio
async def test_global_filter_in_nested_router(self):
r1 = Router()
r2 = Router()

async def handler(evt):
return evt

r1.include_router(r2)
r1.message.filter(lambda evt: False)
r2.message.register(handler)

assert await r1.message.trigger(None) is REJECTED
6 changes: 3 additions & 3 deletions tests/test_dispatcher/test_filters/test_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
class TestExceptionMessageFilter:
@pytest.mark.parametrize("value", ["value", re.compile("value")])
def test_converter(self, value):
obj = ExceptionMessageFilter(match=value)
assert isinstance(obj.match, re.Pattern)
obj = ExceptionMessageFilter(pattern=value)
assert isinstance(obj.pattern, re.Pattern)

@pytest.mark.asyncio
async def test_match(self):
obj = ExceptionMessageFilter(match="KABOOM")
obj = ExceptionMessageFilter(pattern="KABOOM")

result = await obj(Exception())
assert not result
Expand Down
16 changes: 15 additions & 1 deletion tests/test_dispatcher/test_router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from aiogram.dispatcher.event.bases import SkipHandler, skip
from aiogram.dispatcher.event.bases import SkipHandler, skip, UNHANDLED
from aiogram.dispatcher.router import Router
from aiogram.utils.warnings import CodeHasNoEffect

Expand Down Expand Up @@ -122,3 +122,17 @@ def test_skip(self):
skip()
with pytest.raises(SkipHandler, match="KABOOM"):
skip("KABOOM")

@pytest.mark.asyncio
async def test_global_filter_in_nested_router(self):
r1 = Router()
r2 = Router()

async def handler(evt):
return evt

r1.include_router(r2)
r1.message.filter(lambda evt: False)
r2.message.register(handler)

assert await r1.propagate_event(update_type="message", event=None) is UNHANDLED

0 comments on commit 4f2cc75

Please sign in to comment.