-
-
Notifications
You must be signed in to change notification settings - Fork 806
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
200 additions
and
79 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import typing | ||
|
||
from aiogram.dispatcher import BaseStorage | ||
|
||
|
||
class MemoryStorage(BaseStorage): | ||
""" | ||
In-memory based states storage. | ||
This type of storage is not recommended for usage in bots, because you will lost all states after restarting. | ||
""" | ||
|
||
def __init__(self): | ||
self.data = {} | ||
|
||
def _get_chat(self, chat_id): | ||
chat_id = str(chat_id) | ||
if chat_id not in self.data: | ||
self.data[chat_id] = {} | ||
return self.data[chat_id] | ||
|
||
def _get_user(self, chat_id, user_id): | ||
chat = self._get_chat(chat_id) | ||
chat_id = str(chat_id) | ||
user_id = str(user_id) | ||
if user_id not in self.data[chat_id]: | ||
self.data[chat_id][user_id] = {'state': None, 'data': {}} | ||
return self.data[chat_id][user_id] | ||
|
||
async def get_state(self, *, | ||
chat: typing.Union[str, int, None] = None, | ||
user: typing.Union[str, int, None] = None, | ||
default: typing.Optional[str] = None) -> typing.Optional[str]: | ||
chat, user = self.check_address(chat=chat, user=user) | ||
user = self._get_user(chat, user) | ||
return user['state'] | ||
|
||
async def get_data(self, *, | ||
chat: typing.Union[str, int, None] = None, | ||
user: typing.Union[str, int, None] = None, | ||
default: typing.Optional[str] = None) -> typing.Dict: | ||
chat, user = self.check_address(chat=chat, user=user) | ||
user = self._get_user(chat, user) | ||
return user['data'] | ||
|
||
async def update_data(self, *, | ||
chat: typing.Union[str, int, None] = None, | ||
user: typing.Union[str, int, None] = None, | ||
data: typing.Dict = None, **kwargs): | ||
chat, user = self.check_address(chat=chat, user=user) | ||
user = self._get_user(chat, user) | ||
if data is None: | ||
data = [] | ||
user['data'].update(data, **kwargs) | ||
|
||
async def set_state(self, *, | ||
chat: typing.Union[str, int, None] = None, | ||
user: typing.Union[str, int, None] = None, | ||
state: typing.AnyStr = None): | ||
chat, user = self.check_address(chat=chat, user=user) | ||
user = self._get_user(chat, user) | ||
user['state'] = state | ||
|
||
async def set_data(self, *, | ||
chat: typing.Union[str, int, None] = None, | ||
user: typing.Union[str, int, None] = None, | ||
data: typing.Dict = None): | ||
chat, user = self.check_address(chat=chat, user=user) | ||
user = self._get_user(chat, user) | ||
user['data'] = data | ||
|
||
async def reset_state(self, *, | ||
chat: typing.Union[str, int, None] = None, | ||
user: typing.Union[str, int, None] = None, | ||
with_data: typing.Optional[bool] = True): | ||
await self.set_state(chat=chat, user=user, state=None) | ||
if with_data: | ||
await self.set_data(chat=chat, user=user, data={}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
""" | ||
This module has redis storage for finite-state machine based on `aioredis <https://github.com/aio-libs/aioredis>`_ driver | ||
""" | ||
|
||
import typing | ||
|
||
import aioredis | ||
|
||
from aiogram.utils import json | ||
from ...dispatcher.storage import BaseStorage | ||
|
||
|
||
class RedisStorage(BaseStorage): | ||
""" | ||
Simple Redis-base storage for FSM. | ||
Usage: | ||
.. codeblock:: python3 | ||
storage = RedisStorage('localhost', 6379, db=5) | ||
dp = Dispatcher(bot, storage=storage) | ||
""" | ||
|
||
def __init__(self, host, port, db=None, password=None, ssl=None, loop=None, **kwargs): | ||
self._host = host | ||
self._port = port | ||
self._db: aioredis.RedisConnection = db | ||
self._password = password | ||
self._ssl = ssl | ||
self._loop = loop | ||
self._kwargs = kwargs | ||
|
||
self._redis = None | ||
|
||
@property | ||
async def redis(self) -> aioredis.RedisConnection: | ||
""" | ||
Get Redis connection | ||
This property is awaitable. | ||
:return: | ||
""" | ||
if self._redis is None: | ||
self._redis = await aioredis.create_connection((self._host, self._port), | ||
db=self._db, password=self._password, ssl=self._ssl, | ||
loop=self._loop, | ||
**self._kwargs) | ||
return self._redis | ||
|
||
async def get_record(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None): | ||
""" | ||
Get record from storage | ||
:param chat: | ||
:param user: | ||
:return: | ||
""" | ||
chat, user = self.check_address(chat=chat, user=user) | ||
addr = f"{chat}:{user}" | ||
|
||
redis = await self.redis | ||
data = await redis.execute('GET', addr) | ||
if data is None: | ||
return {'state': None, 'data': {}} | ||
return json.loads(data) | ||
|
||
async def set_record(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, | ||
state=None, data=None): | ||
""" | ||
Write record to storage | ||
:param chat: | ||
:param user: | ||
:param state: | ||
:param data: | ||
:return: | ||
""" | ||
if data is None: | ||
data = {} | ||
|
||
chat, user = self.check_address(chat=chat, user=user) | ||
addr = f"{chat}:{user}" | ||
|
||
record = {'state': state, 'data': data} | ||
|
||
conn = await self.redis | ||
await conn.execute('SET', addr, json.dumps(record)) | ||
|
||
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, | ||
default: typing.Optional[str] = None) -> typing.Optional[str]: | ||
record = await self.get_record(chat=chat, user=user) | ||
return record['state'] | ||
|
||
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, | ||
default: typing.Optional[str] = None) -> typing.Dict: | ||
record = await self.get_record(chat=chat, user=user) | ||
return record['data'] | ||
|
||
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, | ||
state: typing.Optional[typing.AnyStr] = None): | ||
record = await self.get_record(chat=chat, user=user) | ||
await self.set_record(chat=chat, user=user, state=state, data=record['data']) | ||
|
||
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, | ||
data: typing.Dict = None): | ||
record = await self.get_record(chat=chat, user=user) | ||
await self.set_record(chat=chat, user=user, state=record['state'], data=data) | ||
|
||
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, | ||
data: typing.Dict = None, **kwargs): | ||
data = await self.get_data(chat=chat, user=user) | ||
if data is None: | ||
data = [] | ||
data.update(data, **kwargs) | ||
await self.set_data(chat=chat, user=user, data=data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters