Skip to content

Commit

Permalink
Added ability to specify which update bot need to receive and process…
Browse files Browse the repository at this point in the history
… while using polling mode (#617)

* provide allowed_updates in polling mode
  • Loading branch information
Forevka committed Jul 4, 2021
1 parent eee6589 commit 125fc22
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 4 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ aiogram/_meta.py
.coverage
reports

dev/
dev/
.venv/
16 changes: 13 additions & 3 deletions aiogram/dispatcher/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import contextvars
import warnings
from asyncio import CancelledError, Future, Lock
from typing import Any, AsyncGenerator, Dict, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Union

from .. import loggers
from ..client.bot import Bot
Expand Down Expand Up @@ -130,14 +130,15 @@ async def _listen_updates(
bot: Bot,
polling_timeout: int = 30,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
) -> AsyncGenerator[Update, None]:
"""
Endless updates reader with correctly handling any server-side or connection errors.
So you may not worry that the polling will stop working.
"""
backoff = Backoff(config=backoff_config)
get_updates = GetUpdates(timeout=polling_timeout)
get_updates = GetUpdates(timeout=polling_timeout, allowed_updates=allowed_updates)
kwargs = {}
if bot.session.timeout:
# Request timeout can be lower than session timeout ant that's OK.
Expand Down Expand Up @@ -297,6 +298,7 @@ async def _polling(
polling_timeout: int = 30,
handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""
Expand All @@ -307,7 +309,10 @@ async def _polling(
:return:
"""
async for update in self._listen_updates(
bot, polling_timeout=polling_timeout, backoff_config=backoff_config
bot,
polling_timeout=polling_timeout,
backoff_config=backoff_config,
allowed_updates=allowed_updates,
):
handle_update = self._process_update(bot=bot, update=update, **kwargs)
if handle_as_tasks:
Expand Down Expand Up @@ -397,6 +402,7 @@ async def start_polling(
polling_timeout: int = 10,
handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -427,6 +433,7 @@ async def start_polling(
handle_as_tasks=handle_as_tasks,
polling_timeout=polling_timeout,
backoff_config=backoff_config,
allowed_updates=allowed_updates,
**kwargs,
)
)
Expand All @@ -443,6 +450,7 @@ def run_polling(
polling_timeout: int = 30,
handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""
Expand All @@ -452,6 +460,7 @@ def run_polling(
:param polling_timeout: Poling timeout
:param backoff_config:
:param handle_as_tasks: Run task for each event and no wait result
:param allowed_updates: List of the update types you want your bot to receive
:param kwargs: contextual data
:return:
"""
Expand All @@ -463,6 +472,7 @@ def run_polling(
polling_timeout=polling_timeout,
handle_as_tasks=handle_as_tasks,
backoff_config=backoff_config,
allowed_updates=allowed_updates,
)
)
except (KeyboardInterrupt, SystemExit): # pragma: no cover
Expand Down
28 changes: 28 additions & 0 deletions aiogram/utils/handlers_in_use.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from itertools import chain
from typing import List, cast

from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.router import Router

INTERNAL_HANDLERS = [
"update",
"error",
]


def get_handlers_in_use(
dispatcher: Dispatcher, handlers_to_skip: List[str] = INTERNAL_HANDLERS
) -> List[str]:
handlers_in_use: List[str] = []

for router in [dispatcher.sub_routers, dispatcher]:
if isinstance(router, list):
if router:
handlers_in_use.extend(chain(*list(map(get_handlers_in_use, router))))
else:
router = cast(Router, router)
for update_name, observer in router.observers.items():
if observer.handlers and update_name not in [*handlers_to_skip, *handlers_in_use]:
handlers_in_use.append(update_name)

return handlers_in_use
87 changes: 87 additions & 0 deletions examples/specify_updates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from aiogram.types.inline_keyboard_button import InlineKeyboardButton
from aiogram.types.inline_keyboard_markup import InlineKeyboardMarkup
from aiogram.dispatcher.router import Router
from aiogram.utils.handlers_in_use import get_handlers_in_use
import logging

from aiogram import Bot, Dispatcher
from aiogram.types import Message, ChatMemberUpdated, CallbackQuery

TOKEN = "6wo"
dp = Dispatcher()

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


@dp.message(commands={"start"})
async def command_start_handler(message: Message) -> None:
"""
This handler receive messages with `/start` command
"""

await message.answer(
f"Hello, <b>{message.from_user.full_name}!</b>",
reply_markup=InlineKeyboardMarkup(
inline_keyboard=[[InlineKeyboardButton(text="Tap me, bro", callback_data="*")]]
),
)


@dp.chat_member()
async def chat_member_update(chat_member: ChatMemberUpdated, bot: Bot) -> None:
await bot.send_message(
chat_member.chat.id,
"Member {chat_member.from_user.id} was changed "
+ f"from {chat_member.old_chat_member.is_chat_member} to {chat_member.new_chat_member.is_chat_member}",
)


# this router will use only callback_query updates
sub_router = Router()


@sub_router.callback_query()
async def callback_tap_me(callback_query: CallbackQuery) -> None:
await callback_query.answer("Yeah good, now i'm fine")


# this router will use only edited_message updates
sub_sub_router = Router()


@sub_sub_router.edited_message()
async def callback_tap_me(edited_message: Message) -> None:
await edited_message.reply("Message was edited, big brother watch you")


# this router will use only my_chat_member updates
deep_dark_router = Router()


@deep_dark_router.my_chat_member()
async def my_chat_member_change(chat_member: ChatMemberUpdated, bot: Bot) -> None:
await bot.send_message(
chat_member.chat.id,
"Member was changed from "
+ f"{chat_member.old_chat_member.is_chat_member} to {chat_member.new_chat_member.is_chat_member}",
)


def main() -> None:
# Initialize Bot instance with an default parse mode which will be passed to all API calls
bot = Bot(TOKEN, parse_mode="HTML")

sub_router.include_router(deep_dark_router)

dp.include_router(sub_router)
dp.include_router(sub_sub_router)

useful_updates = get_handlers_in_use(dp)

# And the run events dispatching
dp.run_polling(bot, allowed_updates=useful_updates)


if __name__ == "__main__":
main()
54 changes: 54 additions & 0 deletions tests/test_dispatcher/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Update,
User,
)
from aiogram.utils.handlers_in_use import get_handlers_in_use
from tests.mocked_bot import MockedBot

try:
Expand Down Expand Up @@ -659,3 +660,56 @@ async def test_feed_webhook_update_fast_process_error(self, bot: MockedBot, capl

log_records = [rec.message for rec in caplog.records]
assert "Cause exception while process update" in log_records[0]

def test_specify_updates_calculation(self):
def simple_msg_handler() -> None:
...

def simple_callback_query_handler() -> None:
...

def simple_poll_handler() -> None:
...

def simple_edited_msg_handler() -> None:
...

dispatcher = Dispatcher()
dispatcher.message.register(simple_msg_handler)

router1 = Router()
router1.callback_query.register(simple_callback_query_handler)

router2 = Router()
router2.poll.register(simple_poll_handler)

router21 = Router()
router21.edited_message.register(simple_edited_msg_handler)

useful_updates1 = get_handlers_in_use(dispatcher)

assert sorted(useful_updates1) == sorted(["message"])

dispatcher.include_router(router1)

useful_updates2 = get_handlers_in_use(dispatcher)

assert sorted(useful_updates2) == sorted(["message", "callback_query"])

dispatcher.include_router(router2)

useful_updates3 = get_handlers_in_use(dispatcher)

assert sorted(useful_updates3) == sorted(["message", "callback_query", "poll"])

router2.include_router(router21)

useful_updates4 = get_handlers_in_use(dispatcher)

assert sorted(useful_updates4) == sorted(
["message", "callback_query", "poll", "edited_message"]
)

useful_updates5 = get_handlers_in_use(router2)

assert sorted(useful_updates5) == sorted(["poll", "edited_message"])

0 comments on commit 125fc22

Please sign in to comment.