From 0c7994bd87825b457a3153f2cf4f9e3e09792c85 Mon Sep 17 00:00:00 2001 From: Tishka17 Date: Sun, 24 Apr 2022 22:27:52 +0300 Subject: [PATCH 1/3] fake state storage proxy --- aiogram_dialog/context/stack.py | 27 ++------------ aiogram_dialog/context/stateless.py | 56 +++++++++++++++++++++++++++++ aiogram_dialog/context/storage.py | 31 ++++++++++++++++ aiogram_dialog/manager/manager.py | 5 +-- 4 files changed, 93 insertions(+), 26 deletions(-) create mode 100644 aiogram_dialog/context/stateless.py diff --git a/aiogram_dialog/context/stack.py b/aiogram_dialog/context/stack.py index 9381debd..cd046ba0 100644 --- a/aiogram_dialog/context/stack.py +++ b/aiogram_dialog/context/stack.py @@ -12,31 +12,10 @@ DEFAULT_STACK_ID = "" STACK_LIMIT = 100 -ID_SYMS = string.digits + string.ascii_letters - - -def new_int_id() -> int: - return int(time.time()) % 100000000 + random.randint(0, 99) * 100000000 - - -def id_to_str(int_id: int) -> str: - if not int_id: - return ID_SYMS[0] - base = len(ID_SYMS) - res = "" - while int_id: - int_id, mod = divmod(int_id, base) - res += ID_SYMS[mod] - return res - - -def new_id(): - return id_to_str(new_int_id()) - @dataclass(unsafe_hash=True) class Stack: - _id: str = field(compare=True, default_factory=new_id) + _id: str = field(compare=True) intents: List[str] = field(compare=False, default_factory=list) last_message_id: Optional[int] = field(compare=False, default=None) last_media_id: Optional[str] = field(compare=False, default=None) @@ -47,13 +26,13 @@ class Stack: def id(self): return self._id - def push(self, state: State, data: Data) -> Context: + def push(self, state: State, intent_id: str, data: Data) -> Context: if len(self.intents) >= STACK_LIMIT: raise DialogStackOverflow( f"Cannot open more dialogs in current stack. Max count is {STACK_LIMIT}" ) context = Context( - _intent_id=new_id(), + _intent_id=intent_id, _stack_id=self.id, state=state, start_data=data, diff --git a/aiogram_dialog/context/stateless.py b/aiogram_dialog/context/stateless.py new file mode 100644 index 00000000..cd3a3cff --- /dev/null +++ b/aiogram_dialog/context/stateless.py @@ -0,0 +1,56 @@ +from typing import Dict, Type, Optional + +from aiogram.dispatcher.filters.state import State, StatesGroup +from aiogram.dispatcher.storage import BaseStorage + +from .context import Context +from .stack import Stack, DEFAULT_STACK_ID +from ..exceptions import UnknownState + + +class FakeStorageProxy: + def __init__(self, storage: BaseStorage, + user_id: int, chat_id: int, + state_groups: Dict[str, Type[StatesGroup]]): + self.storage = storage + self.state_groups = state_groups + self.user_id = user_id + self.chat_id = chat_id + + async def new_intent_id(self, state: State) -> str: + return str(state) + + async def new_stack_id(self) -> str: + raise DEFAULT_STACK_ID + + async def load_context(self, intent_id: str) -> Context: + return Context( + _intent_id=intent_id, + _stack_id=DEFAULT_STACK_ID, + state=self._state(intent_id), + start_data=None, + dialog_data={}, + widget_data={}, + ) + + async def load_stack(self, stack_id: str = DEFAULT_STACK_ID) -> Stack: + return Stack(_id=stack_id) + + async def save_context(self, context: Optional[Context]) -> None: + pass + + async def remove_context(self, intent_id: str): + pass + + async def remove_stack(self, stack_id: str): + pass + + async def save_stack(self, stack: Optional[Stack]) -> None: + pass + + def _state(self, state: str) -> State: + group, *_ = state.partition(":") + for real_state in self.state_groups[group].all_states: + if real_state.state == state: + return real_state + raise UnknownState(f"Unknown state {state}") diff --git a/aiogram_dialog/context/storage.py b/aiogram_dialog/context/storage.py index 6f243355..43d98ee4 100644 --- a/aiogram_dialog/context/storage.py +++ b/aiogram_dialog/context/storage.py @@ -1,3 +1,6 @@ +import random +import string +import time from copy import copy from typing import Dict, Type, Optional @@ -9,6 +12,28 @@ from ..exceptions import UnknownState, UnknownIntent +ID_SYMS = string.digits + string.ascii_letters + + +def new_int_id() -> int: + return int(time.time()) % 100000000 + random.randint(0, 99) * 100000000 + + +def id_to_str(int_id: int) -> str: + if not int_id: + return ID_SYMS[0] + base = len(ID_SYMS) + res = "" + while int_id: + int_id, mod = divmod(int_id, base) + res += ID_SYMS[mod] + return res + + +def new_id(): + return id_to_str(new_int_id()) + + class StorageProxy: def __init__(self, storage: BaseStorage, user_id: int, chat_id: int, @@ -18,6 +43,12 @@ def __init__(self, storage: BaseStorage, self.user_id = user_id self.chat_id = chat_id + async def new_intent_id(self, state: State) -> str: + return new_id() + + async def new_stack_id(self) -> str: + return new_id() + async def load_context(self, intent_id: str) -> Context: data = await self.storage.get_data( chat=self.chat_id, diff --git a/aiogram_dialog/manager/manager.py b/aiogram_dialog/manager/manager.py index 5d3d006d..0e381690 100644 --- a/aiogram_dialog/manager/manager.py +++ b/aiogram_dialog/manager/manager.py @@ -123,7 +123,7 @@ async def reset_stack(self, remove_keyboard: bool = True) -> None: self.data[CONTEXT_KEY] = None async def _start_new_stack(self, state: State, data: Data = None) -> None: - stack = Stack() + stack = Stack(_id=await self.storage().new_stack_id()) await self.bg(stack_id=stack.id).start(state, data, StartMode.NORMAL) async def _start_normal(self, state: State, data: Data = None) -> None: @@ -145,7 +145,8 @@ async def _start_normal(self, state: State, data: Data = None) -> None: await self.storage().remove_context(stack.pop()) await self.storage().save_context(self.current_context()) - context = stack.push(state, data) + new_intent_id = await self.storage().new_intent_id(state) + context = stack.push(state, new_intent_id, data) self.data[CONTEXT_KEY] = context await self._dialog().process_start(self, data, state) if context.id == self.current_context().id: From ac52bc3b9f6ec43e7a01b960a5c72b76672ceca6 Mon Sep 17 00:00:00 2001 From: Tishka17 Date: Sun, 24 Apr 2022 22:37:18 +0300 Subject: [PATCH 2/3] storage proxy factory --- aiogram_dialog/context/intent_filter.py | 44 ++++++++++--------------- aiogram_dialog/context/stateless.py | 22 ++++++++++--- aiogram_dialog/context/storage.py | 30 ++++++++++++++--- aiogram_dialog/manager/registry.py | 14 +++++--- 4 files changed, 71 insertions(+), 39 deletions(-) diff --git a/aiogram_dialog/context/intent_filter.py b/aiogram_dialog/context/intent_filter.py index 36853c75..db23fbc5 100644 --- a/aiogram_dialog/context/intent_filter.py +++ b/aiogram_dialog/context/intent_filter.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Optional, Type, Dict, Union, Any +from typing import Optional, Type, Dict, Union, Any, Protocol from aiogram.dispatcher.filters import BoundFilter from aiogram.dispatcher.filters.state import StatesGroup @@ -23,11 +23,15 @@ logger = getLogger(__name__) +class StorageProxyFactoryProtocol(Protocol): + def __call__(self, user_id: int, chat_id: int): + raise NotImplementedError + + class IntentFilter(BoundFilter): key = 'aiogd_intent_state_group' - def __init__(self, - aiogd_intent_state_group: Optional[Type[StatesGroup]] = None): + def __init__(self, aiogd_intent_state_group: Optional[Type[StatesGroup]] = None): self.intent_state_group = aiogd_intent_state_group async def check(self, obj: TelegramObject): @@ -41,21 +45,16 @@ async def check(self, obj: TelegramObject): class IntentMiddleware(BaseMiddleware): - def __init__(self, storage: BaseStorage, - state_groups: Dict[str, Type[StatesGroup]]): + def __init__(self, storage_proxy_factory: StorageProxyFactoryProtocol): super().__init__() - self.storage = storage - self.state_groups = state_groups + self.storage_proxy_factory = storage_proxy_factory async def on_pre_process_message(self, event: Union[Message, ChatMemberUpdated], data: dict): chat = get_chat(event) - proxy = StorageProxy( - storage=self.storage, - user_id=event.from_user.id, - chat_id=chat.id, - state_groups=self.state_groups, + proxy = self.storage_proxy_factory( + user_id=event.from_user.id, chat_id=chat.id, ) stack = await proxy.load_stack() if stack.empty(): @@ -74,11 +73,8 @@ async def on_post_process_message(self, _, result, data: dict): async def on_pre_process_aiogd_update(self, event: DialogUpdateEvent, data: dict): chat = get_chat(event) - proxy = StorageProxy( - storage=self.storage, - user_id=event.from_user.id, - chat_id=chat.id, - state_groups=self.state_groups, + proxy = self.storage_proxy_factory( + user_id=event.from_user.id, chat_id=chat.id, ) data[STORAGE_KEY] = proxy if event.intent_id is not None: @@ -112,11 +108,8 @@ async def on_pre_process_aiogd_update(self, event: DialogUpdateEvent, async def on_pre_process_callback_query(self, event: CallbackQuery, data: dict): chat = get_chat(event) - proxy = StorageProxy( - storage=self.storage, - user_id=event.from_user.id, - chat_id=chat.id, - state_groups=self.state_groups, + proxy = self.storage_proxy_factory( + user_id=event.from_user.id, chat_id=chat.id, ) data[STORAGE_KEY] = proxy @@ -162,11 +155,8 @@ async def on_pre_process_error(self, update: Update, error: Exception, chat = get_chat(event) - proxy = StorageProxy( - storage=self.storage, - user_id=event.from_user.id, - chat_id=chat.id, - state_groups=self.state_groups, + proxy = self.storage_proxy_factory( + user_id=event.from_user.id, chat_id=chat.id, ) data[STORAGE_KEY] = proxy diff --git a/aiogram_dialog/context/stateless.py b/aiogram_dialog/context/stateless.py index cd3a3cff..8460b810 100644 --- a/aiogram_dialog/context/stateless.py +++ b/aiogram_dialog/context/stateless.py @@ -1,18 +1,16 @@ from typing import Dict, Type, Optional from aiogram.dispatcher.filters.state import State, StatesGroup -from aiogram.dispatcher.storage import BaseStorage from .context import Context from .stack import Stack, DEFAULT_STACK_ID from ..exceptions import UnknownState -class FakeStorageProxy: - def __init__(self, storage: BaseStorage, +class StatelessStorageProxy: + def __init__(self, user_id: int, chat_id: int, state_groups: Dict[str, Type[StatesGroup]]): - self.storage = storage self.state_groups = state_groups self.user_id = user_id self.chat_id = chat_id @@ -54,3 +52,19 @@ def _state(self, state: str) -> State: if real_state.state == state: return real_state raise UnknownState(f"Unknown state {state}") + + +class StatelessStorageProxyFactory: + def __init__(self, state_groups: Dict[str, Type[StatesGroup]]): + self.state_groups = state_groups + + def __call__( + self, + user_id: int, chat_id: int, + state_groups: Dict[str, Type[StatesGroup]] + ): + return StatelessStorageProxy( + user_id=user_id, + chat_id=chat_id, + state_groups=self.state_groups, + ) diff --git a/aiogram_dialog/context/storage.py b/aiogram_dialog/context/storage.py index 43d98ee4..ff1b7cf6 100644 --- a/aiogram_dialog/context/storage.py +++ b/aiogram_dialog/context/storage.py @@ -11,7 +11,6 @@ from .stack import Stack, DEFAULT_STACK_ID from ..exceptions import UnknownState, UnknownIntent - ID_SYMS = string.digits + string.ascii_letters @@ -55,7 +54,8 @@ async def load_context(self, intent_id: str) -> Context: user=self._context_key(intent_id) ) if not data: - raise UnknownIntent(f"Context not found for intent id: {intent_id}") + raise UnknownIntent( + f"Context not found for intent id: {intent_id}") data["state"] = self._state(data["state"]) return Context(**data) @@ -80,10 +80,12 @@ async def save_context(self, context: Optional[Context]) -> None: ) async def remove_context(self, intent_id: str): - await self.storage.reset_data(chat=self.chat_id, user=self._context_key(intent_id)) + await self.storage.reset_data(chat=self.chat_id, + user=self._context_key(intent_id)) async def remove_stack(self, stack_id: str): - await self.storage.reset_data(chat=self.chat_id, user=self._stack_key(stack_id)) + await self.storage.reset_data(chat=self.chat_id, + user=self._stack_key(stack_id)) async def save_stack(self, stack: Optional[Stack]) -> None: if not stack: @@ -113,3 +115,23 @@ def _state(self, state: str) -> State: if real_state.state == state: return real_state raise UnknownState(f"Unknown state {state}") + + +class StorageProxyFactory: + def __init__(self, + storage: BaseStorage, + state_groups: Dict[str, Type[StatesGroup]]): + self.storage = storage + self.state_groups = state_groups + + def __call__( + self, + user_id: int, chat_id: int, + state_groups: Dict[str, Type[StatesGroup]] + ): + return StorageProxy( + storage=self.storage, + user_id=user_id, + chat_id=chat_id, + state_groups=self.state_groups, + ) diff --git a/aiogram_dialog/manager/registry.py b/aiogram_dialog/manager/registry.py index daca71df..cbd14f32 100644 --- a/aiogram_dialog/manager/registry.py +++ b/aiogram_dialog/manager/registry.py @@ -14,8 +14,10 @@ ) from .update_handler import handle_update from ..context.events import DialogUpdateEvent, StartMode -from ..context.intent_filter import IntentFilter, IntentMiddleware +from ..context.intent_filter import IntentFilter, IntentMiddleware, \ + StorageProxyFactoryProtocol from ..context.media_storage import MediaIdStorage +from ..context.storage import StorageProxyFactory from ..exceptions import UnregisteredDialogError @@ -25,6 +27,7 @@ def __init__( dp: Dispatcher, dialogs: Sequence[ManagedDialogProto] = (), media_id_storage: Optional[MediaIdStorageProtocol] = None, + storage_proxy_factory: Optional[StorageProxyFactoryProtocol] = None, ): self.dp = dp self.dialogs = { @@ -33,6 +36,11 @@ def __init__( self.state_groups: Dict[str, Type[StatesGroup]] = { d.states_group_name(): d.states_group() for d in dialogs } + if storage_proxy_factory is None: + storage_proxy_factory = StorageProxyFactory( + storage=dp.storage, state_groups=self.state_groups + ) + self.storage_proxy_factory = storage_proxy_factory self.update_handler = Handler(dp, middleware_key="aiogd_update") self.register_update_handler(handle_update, state="*") self.dp.filters_factory.bind(IntentFilter) @@ -68,9 +76,7 @@ def _register_middleware(self): self.dp.setup_middleware( ManagerMiddleware(self) ) - self.dp.setup_middleware( - IntentMiddleware(storage=self.dp.storage, state_groups=self.state_groups) - ) + self.dp.setup_middleware(IntentMiddleware(self.storage_proxy_factory)) def find_dialog(self, state: State) -> ManagedDialogProto: try: From 4ecd01b63e1f55d0c3a4452a02304cb9f9c77feb Mon Sep 17 00:00:00 2001 From: Tishka17 Date: Sun, 24 Apr 2022 22:54:08 +0300 Subject: [PATCH 3/3] fix stateless --- aiogram_dialog/context/stateless.py | 24 ++++++++++++++++-------- aiogram_dialog/manager/registry.py | 15 +++++++++------ 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/aiogram_dialog/context/stateless.py b/aiogram_dialog/context/stateless.py index 8460b810..4154dfdf 100644 --- a/aiogram_dialog/context/stateless.py +++ b/aiogram_dialog/context/stateless.py @@ -7,6 +7,12 @@ from ..exceptions import UnknownState +class StatelessContext(Context): + @property + def id(self): + return self.state.state + + class StatelessStorageProxy: def __init__(self, user_id: int, chat_id: int, @@ -14,15 +20,17 @@ def __init__(self, self.state_groups = state_groups self.user_id = user_id self.chat_id = chat_id + self.intent_id = None async def new_intent_id(self, state: State) -> str: - return str(state) + return state.state async def new_stack_id(self) -> str: raise DEFAULT_STACK_ID async def load_context(self, intent_id: str) -> Context: - return Context( + self.intent_id = intent_id + return StatelessContext( _intent_id=intent_id, _stack_id=DEFAULT_STACK_ID, state=self._state(intent_id), @@ -32,7 +40,11 @@ async def load_context(self, intent_id: str) -> Context: ) async def load_stack(self, stack_id: str = DEFAULT_STACK_ID) -> Stack: - return Stack(_id=stack_id) + if self.intent_id: + intents = [self.intent_id] + else: + intents = [] + return Stack(_id=stack_id, intents=intents) async def save_context(self, context: Optional[Context]) -> None: pass @@ -58,11 +70,7 @@ class StatelessStorageProxyFactory: def __init__(self, state_groups: Dict[str, Type[StatesGroup]]): self.state_groups = state_groups - def __call__( - self, - user_id: int, chat_id: int, - state_groups: Dict[str, Type[StatesGroup]] - ): + def __call__(self, user_id: int, chat_id: int) -> StatelessStorageProxy: return StatelessStorageProxy( user_id=user_id, chat_id=chat_id, diff --git a/aiogram_dialog/manager/registry.py b/aiogram_dialog/manager/registry.py index cbd14f32..8fc886f8 100644 --- a/aiogram_dialog/manager/registry.py +++ b/aiogram_dialog/manager/registry.py @@ -14,9 +14,9 @@ ) from .update_handler import handle_update from ..context.events import DialogUpdateEvent, StartMode -from ..context.intent_filter import IntentFilter, IntentMiddleware, \ - StorageProxyFactoryProtocol +from ..context.intent_filter import IntentFilter, IntentMiddleware from ..context.media_storage import MediaIdStorage +from ..context.stateless import StatelessStorageProxyFactory from ..context.storage import StorageProxyFactory from ..exceptions import UnregisteredDialogError @@ -27,7 +27,7 @@ def __init__( dp: Dispatcher, dialogs: Sequence[ManagedDialogProto] = (), media_id_storage: Optional[MediaIdStorageProtocol] = None, - storage_proxy_factory: Optional[StorageProxyFactoryProtocol] = None, + stateless: bool = False, ): self.dp = dp self.dialogs = { @@ -36,11 +36,14 @@ def __init__( self.state_groups: Dict[str, Type[StatesGroup]] = { d.states_group_name(): d.states_group() for d in dialogs } - if storage_proxy_factory is None: - storage_proxy_factory = StorageProxyFactory( + if stateless: + self.storage_proxy_factory = StatelessStorageProxyFactory( + state_groups=self.state_groups, + ) + else: + self.storage_proxy_factory = StorageProxyFactory( storage=dp.storage, state_groups=self.state_groups ) - self.storage_proxy_factory = storage_proxy_factory self.update_handler = Handler(dp, middleware_key="aiogd_update") self.register_update_handler(handle_update, state="*") self.dp.filters_factory.bind(IntentFilter)