Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stateless dialogs #140

Draft
wants to merge 3 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 17 additions & 27 deletions aiogram_dialog/context/intent_filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from logging import getLogger
from typing import Optional, Type, Dict, Union, Any
from typing import Optional, Type, Dict, Union, Any, Protocol

from aiogram.dispatcher.filters import BoundFilter
from aiogram.dispatcher.filters.state import StatesGroup
Expand All @@ -23,11 +23,15 @@
logger = getLogger(__name__)


class StorageProxyFactoryProtocol(Protocol):
def __call__(self, user_id: int, chat_id: int):
raise NotImplementedError


class IntentFilter(BoundFilter):
key = 'aiogd_intent_state_group'

def __init__(self,
aiogd_intent_state_group: Optional[Type[StatesGroup]] = None):
def __init__(self, aiogd_intent_state_group: Optional[Type[StatesGroup]] = None):
self.intent_state_group = aiogd_intent_state_group

async def check(self, obj: TelegramObject):
Expand All @@ -41,21 +45,16 @@ async def check(self, obj: TelegramObject):


class IntentMiddleware(BaseMiddleware):
def __init__(self, storage: BaseStorage,
state_groups: Dict[str, Type[StatesGroup]]):
def __init__(self, storage_proxy_factory: StorageProxyFactoryProtocol):
super().__init__()
self.storage = storage
self.state_groups = state_groups
self.storage_proxy_factory = storage_proxy_factory

async def on_pre_process_message(self,
event: Union[Message, ChatMemberUpdated],
data: dict):
chat = get_chat(event)
proxy = StorageProxy(
storage=self.storage,
user_id=event.from_user.id,
chat_id=chat.id,
state_groups=self.state_groups,
proxy = self.storage_proxy_factory(
user_id=event.from_user.id, chat_id=chat.id,
)
stack = await proxy.load_stack()
if stack.empty():
Expand All @@ -74,11 +73,8 @@ async def on_post_process_message(self, _, result, data: dict):
async def on_pre_process_aiogd_update(self, event: DialogUpdateEvent,
data: dict):
chat = get_chat(event)
proxy = StorageProxy(
storage=self.storage,
user_id=event.from_user.id,
chat_id=chat.id,
state_groups=self.state_groups,
proxy = self.storage_proxy_factory(
user_id=event.from_user.id, chat_id=chat.id,
)
data[STORAGE_KEY] = proxy
if event.intent_id is not None:
Expand Down Expand Up @@ -112,11 +108,8 @@ async def on_pre_process_aiogd_update(self, event: DialogUpdateEvent,
async def on_pre_process_callback_query(self, event: CallbackQuery,
data: dict):
chat = get_chat(event)
proxy = StorageProxy(
storage=self.storage,
user_id=event.from_user.id,
chat_id=chat.id,
state_groups=self.state_groups,
proxy = self.storage_proxy_factory(
user_id=event.from_user.id, chat_id=chat.id,
)
data[STORAGE_KEY] = proxy

Expand Down Expand Up @@ -162,11 +155,8 @@ async def on_pre_process_error(self, update: Update, error: Exception,

chat = get_chat(event)

proxy = StorageProxy(
storage=self.storage,
user_id=event.from_user.id,
chat_id=chat.id,
state_groups=self.state_groups,
proxy = self.storage_proxy_factory(
user_id=event.from_user.id, chat_id=chat.id,
)
data[STORAGE_KEY] = proxy

Expand Down
27 changes: 3 additions & 24 deletions aiogram_dialog/context/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,10 @@

DEFAULT_STACK_ID = ""
STACK_LIMIT = 100
ID_SYMS = string.digits + string.ascii_letters


def new_int_id() -> int:
return int(time.time()) % 100000000 + random.randint(0, 99) * 100000000


def id_to_str(int_id: int) -> str:
if not int_id:
return ID_SYMS[0]
base = len(ID_SYMS)
res = ""
while int_id:
int_id, mod = divmod(int_id, base)
res += ID_SYMS[mod]
return res


def new_id():
return id_to_str(new_int_id())


@dataclass(unsafe_hash=True)
class Stack:
_id: str = field(compare=True, default_factory=new_id)
_id: str = field(compare=True)
intents: List[str] = field(compare=False, default_factory=list)
last_message_id: Optional[int] = field(compare=False, default=None)
last_media_id: Optional[str] = field(compare=False, default=None)
Expand All @@ -47,13 +26,13 @@ class Stack:
def id(self):
return self._id

def push(self, state: State, data: Data) -> Context:
def push(self, state: State, intent_id: str, data: Data) -> Context:
if len(self.intents) >= STACK_LIMIT:
raise DialogStackOverflow(
f"Cannot open more dialogs in current stack. Max count is {STACK_LIMIT}"
)
context = Context(
_intent_id=new_id(),
_intent_id=intent_id,
_stack_id=self.id,
state=state,
start_data=data,
Expand Down
78 changes: 78 additions & 0 deletions aiogram_dialog/context/stateless.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Dict, Type, Optional

from aiogram.dispatcher.filters.state import State, StatesGroup

from .context import Context
from .stack import Stack, DEFAULT_STACK_ID
from ..exceptions import UnknownState


class StatelessContext(Context):
@property
def id(self):
return self.state.state


class StatelessStorageProxy:
def __init__(self,
user_id: int, chat_id: int,
state_groups: Dict[str, Type[StatesGroup]]):
self.state_groups = state_groups
self.user_id = user_id
self.chat_id = chat_id
self.intent_id = None

async def new_intent_id(self, state: State) -> str:
return state.state

async def new_stack_id(self) -> str:
raise DEFAULT_STACK_ID

async def load_context(self, intent_id: str) -> Context:
self.intent_id = intent_id
return StatelessContext(
_intent_id=intent_id,
_stack_id=DEFAULT_STACK_ID,
state=self._state(intent_id),
start_data=None,
dialog_data={},
widget_data={},
)

async def load_stack(self, stack_id: str = DEFAULT_STACK_ID) -> Stack:
if self.intent_id:
intents = [self.intent_id]
else:
intents = []
return Stack(_id=stack_id, intents=intents)

async def save_context(self, context: Optional[Context]) -> None:
pass

async def remove_context(self, intent_id: str):
pass

async def remove_stack(self, stack_id: str):
pass

async def save_stack(self, stack: Optional[Stack]) -> None:
pass

def _state(self, state: str) -> State:
group, *_ = state.partition(":")
for real_state in self.state_groups[group].all_states:
if real_state.state == state:
return real_state
raise UnknownState(f"Unknown state {state}")


class StatelessStorageProxyFactory:
def __init__(self, state_groups: Dict[str, Type[StatesGroup]]):
self.state_groups = state_groups

def __call__(self, user_id: int, chat_id: int) -> StatelessStorageProxy:
return StatelessStorageProxy(
user_id=user_id,
chat_id=chat_id,
state_groups=self.state_groups,
)
59 changes: 56 additions & 3 deletions aiogram_dialog/context/storage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import random
import string
import time
from copy import copy
from typing import Dict, Type, Optional

Expand All @@ -8,6 +11,27 @@
from .stack import Stack, DEFAULT_STACK_ID
from ..exceptions import UnknownState, UnknownIntent

ID_SYMS = string.digits + string.ascii_letters


def new_int_id() -> int:
return int(time.time()) % 100000000 + random.randint(0, 99) * 100000000


def id_to_str(int_id: int) -> str:
if not int_id:
return ID_SYMS[0]
base = len(ID_SYMS)
res = ""
while int_id:
int_id, mod = divmod(int_id, base)
res += ID_SYMS[mod]
return res


def new_id():
return id_to_str(new_int_id())


class StorageProxy:
def __init__(self, storage: BaseStorage,
Expand All @@ -18,13 +42,20 @@ def __init__(self, storage: BaseStorage,
self.user_id = user_id
self.chat_id = chat_id

async def new_intent_id(self, state: State) -> str:
return new_id()

async def new_stack_id(self) -> str:
return new_id()

async def load_context(self, intent_id: str) -> Context:
data = await self.storage.get_data(
chat=self.chat_id,
user=self._context_key(intent_id)
)
if not data:
raise UnknownIntent(f"Context not found for intent id: {intent_id}")
raise UnknownIntent(
f"Context not found for intent id: {intent_id}")
data["state"] = self._state(data["state"])
return Context(**data)

Expand All @@ -49,10 +80,12 @@ async def save_context(self, context: Optional[Context]) -> None:
)

async def remove_context(self, intent_id: str):
await self.storage.reset_data(chat=self.chat_id, user=self._context_key(intent_id))
await self.storage.reset_data(chat=self.chat_id,
user=self._context_key(intent_id))

async def remove_stack(self, stack_id: str):
await self.storage.reset_data(chat=self.chat_id, user=self._stack_key(stack_id))
await self.storage.reset_data(chat=self.chat_id,
user=self._stack_key(stack_id))

async def save_stack(self, stack: Optional[Stack]) -> None:
if not stack:
Expand Down Expand Up @@ -82,3 +115,23 @@ def _state(self, state: str) -> State:
if real_state.state == state:
return real_state
raise UnknownState(f"Unknown state {state}")


class StorageProxyFactory:
def __init__(self,
storage: BaseStorage,
state_groups: Dict[str, Type[StatesGroup]]):
self.storage = storage
self.state_groups = state_groups

def __call__(
self,
user_id: int, chat_id: int,
state_groups: Dict[str, Type[StatesGroup]]
):
return StorageProxy(
storage=self.storage,
user_id=user_id,
chat_id=chat_id,
state_groups=self.state_groups,
)
5 changes: 3 additions & 2 deletions aiogram_dialog/manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async def reset_stack(self, remove_keyboard: bool = True) -> None:
self.data[CONTEXT_KEY] = None

async def _start_new_stack(self, state: State, data: Data = None) -> None:
stack = Stack()
stack = Stack(_id=await self.storage().new_stack_id())
await self.bg(stack_id=stack.id).start(state, data, StartMode.NORMAL)

async def _start_normal(self, state: State, data: Data = None) -> None:
Expand All @@ -145,7 +145,8 @@ async def _start_normal(self, state: State, data: Data = None) -> None:
await self.storage().remove_context(stack.pop())

await self.storage().save_context(self.current_context())
context = stack.push(state, data)
new_intent_id = await self.storage().new_intent_id(state)
context = stack.push(state, new_intent_id, data)
self.data[CONTEXT_KEY] = context
await self._dialog().process_start(self, data, state)
if context.id == self.current_context().id:
Expand Down
15 changes: 12 additions & 3 deletions aiogram_dialog/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from ..context.events import DialogUpdateEvent, StartMode
from ..context.intent_filter import IntentFilter, IntentMiddleware
from ..context.media_storage import MediaIdStorage
from ..context.stateless import StatelessStorageProxyFactory
from ..context.storage import StorageProxyFactory
from ..exceptions import UnregisteredDialogError


Expand All @@ -25,6 +27,7 @@ def __init__(
dp: Dispatcher,
dialogs: Sequence[ManagedDialogProto] = (),
media_id_storage: Optional[MediaIdStorageProtocol] = None,
stateless: bool = False,
):
self.dp = dp
self.dialogs = {
Expand All @@ -33,6 +36,14 @@ def __init__(
self.state_groups: Dict[str, Type[StatesGroup]] = {
d.states_group_name(): d.states_group() for d in dialogs
}
if stateless:
self.storage_proxy_factory = StatelessStorageProxyFactory(
state_groups=self.state_groups,
)
else:
self.storage_proxy_factory = StorageProxyFactory(
storage=dp.storage, state_groups=self.state_groups
)
self.update_handler = Handler(dp, middleware_key="aiogd_update")
self.register_update_handler(handle_update, state="*")
self.dp.filters_factory.bind(IntentFilter)
Expand Down Expand Up @@ -68,9 +79,7 @@ def _register_middleware(self):
self.dp.setup_middleware(
ManagerMiddleware(self)
)
self.dp.setup_middleware(
IntentMiddleware(storage=self.dp.storage, state_groups=self.state_groups)
)
self.dp.setup_middleware(IntentMiddleware(self.storage_proxy_factory))

def find_dialog(self, state: State) -> ManagedDialogProto:
try:
Expand Down