diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index 87d76374a3..da0b20d84c 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -1,18 +1,18 @@ """ -This module has redis storage for finite-state machine based on `aioredis `_ driver +This module has redis storage for finite-state machine based on `redis `_ driver. """ import asyncio import logging import typing -from abc import ABC, abstractmethod - -import aioredis from ...dispatcher.storage import BaseStorage from ...utils import json from ...utils.deprecated import deprecated +if typing.TYPE_CHECKING: + import aioredis + STATE_KEY = 'state' STATE_DATA_KEY = 'data' STATE_BUCKET_KEY = 'bucket' @@ -67,6 +67,8 @@ async def redis(self) -> "aioredis.RedisConnection": Get Redis connection """ # Use thread-safe asyncio Lock because this method without that is not safe + import aioredis + async with self._connection_lock: if self._redis is None or self._redis.closed: self._redis = await aioredis.create_connection((self._host, self._port), @@ -207,138 +209,6 @@ async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket) -class AioRedisAdapterBase(ABC): - """Base aioredis adapter class.""" - - def __init__( - self, - host: str = "localhost", - port: int = 6379, - db: typing.Optional[int] = None, - password: typing.Optional[str] = None, - ssl: typing.Optional[bool] = None, - pool_size: int = 10, - loop: typing.Optional[asyncio.AbstractEventLoop] = None, - prefix: str = "fsm", - state_ttl: typing.Optional[int] = None, - data_ttl: typing.Optional[int] = None, - bucket_ttl: typing.Optional[int] = None, - **kwargs, - ): - self._host = host - self._port = port - self._db = db - self._password = password - self._ssl = ssl - self._pool_size = pool_size - self._kwargs = kwargs - self._prefix = (prefix,) - - self._state_ttl = state_ttl - self._data_ttl = data_ttl - self._bucket_ttl = bucket_ttl - - self._redis: typing.Optional["aioredis.Redis"] = None - self._connection_lock = asyncio.Lock() - - @abstractmethod - async def get_redis(self) -> aioredis.Redis: - """Get Redis connection.""" - pass - - async def close(self): - """Grace shutdown.""" - pass - - async def wait_closed(self): - """Wait for grace shutdown finishes.""" - pass - - async def set(self, name, value, ex=None, **kwargs): - """Set the value at key ``name`` to ``value``.""" - if ex == 0: - ex = None - return await self._redis.set(name, value, ex=ex, **kwargs) - - async def get(self, name, **kwargs): - """Return the value at key ``name`` or None.""" - return await self._redis.get(name, **kwargs) - - async def delete(self, *names): - """Delete one or more keys specified by ``names``""" - return await self._redis.delete(*names) - - async def keys(self, pattern, **kwargs): - """Returns a list of keys matching ``pattern``.""" - return await self._redis.keys(pattern, **kwargs) - - async def flushdb(self): - """Delete all keys in the current database.""" - return await self._redis.flushdb() - - -class AioRedisAdapterV1(AioRedisAdapterBase): - """Redis adapter for aioredis v1.""" - - async def get_redis(self) -> aioredis.Redis: - """Get Redis connection.""" - async with self._connection_lock: # to prevent race - if self._redis is None or self._redis.closed: - self._redis = await aioredis.create_redis_pool( - (self._host, self._port), - db=self._db, - password=self._password, - ssl=self._ssl, - minsize=1, - maxsize=self._pool_size, - **self._kwargs, - ) - return self._redis - - async def close(self): - async with self._connection_lock: - if self._redis and not self._redis.closed: - self._redis.close() - - async def wait_closed(self): - async with self._connection_lock: - if self._redis: - return await self._redis.wait_closed() - return True - - async def get(self, name, **kwargs): - return await self._redis.get(name, encoding="utf8", **kwargs) - - async def set(self, name, value, ex=None, **kwargs): - if ex == 0: - ex = None - return await self._redis.set(name, value, expire=ex, **kwargs) - - async def keys(self, pattern, **kwargs): - """Returns a list of keys matching ``pattern``.""" - return await self._redis.keys(pattern, encoding="utf8", **kwargs) - - -class AioRedisAdapterV2(AioRedisAdapterBase): - """Redis adapter for aioredis v2.""" - - async def get_redis(self) -> aioredis.Redis: - """Get Redis connection.""" - async with self._connection_lock: # to prevent race - if self._redis is None: - self._redis = aioredis.Redis( - host=self._host, - port=self._port, - db=self._db, - password=self._password, - ssl=self._ssl, - max_connections=self._pool_size, - decode_responses=True, - **self._kwargs, - ) - return self._redis - - class RedisStorage2(BaseStorage): """ Busted Redis-base storage for FSM. @@ -356,7 +226,6 @@ class RedisStorage2(BaseStorage): .. code-block:: python3 await dp.storage.close() - await dp.storage.wait_closed() """ @@ -375,75 +244,49 @@ def __init__( bucket_ttl: typing.Optional[int] = None, **kwargs, ): - self._host = host - self._port = port - self._db = db - self._password = password - self._ssl = ssl - self._pool_size = pool_size - self._kwargs = kwargs - self._prefix = (prefix,) + from redis.asyncio import Redis + + self._redis: typing.Optional[Redis] = Redis( + host=host, + port=port, + db=db, + password=password, + ssl=ssl, + max_connections=pool_size, + decode_responses=True, + **kwargs, + ) + self._prefix = (prefix,) self._state_ttl = state_ttl self._data_ttl = data_ttl self._bucket_ttl = bucket_ttl - self._redis: typing.Optional[AioRedisAdapterBase] = None - self._connection_lock = asyncio.Lock() - @deprecated("This method will be removed in aiogram v3.0. " "You should use your own instance of Redis.", stacklevel=3) - async def redis(self) -> aioredis.Redis: - adapter = await self._get_adapter() - return await adapter.get_redis() - - async def _get_adapter(self) -> AioRedisAdapterBase: - """Get adapter based on aioredis version.""" - if self._redis is None: - redis_version = int(aioredis.__version__.split(".")[0]) - connection_data = dict( - host=self._host, - port=self._port, - db=self._db, - password=self._password, - ssl=self._ssl, - pool_size=self._pool_size, - **self._kwargs, - ) - if redis_version == 1: - self._redis = AioRedisAdapterV1(**connection_data) - elif redis_version == 2: - self._redis = AioRedisAdapterV2(**connection_data) - else: - raise RuntimeError(f"Unsupported aioredis version: {redis_version}") - await self._redis.get_redis() + async def redis(self) -> "aioredis.Redis": return self._redis def generate_key(self, *parts): return ':'.join(self._prefix + tuple(map(str, parts))) async def close(self): - if self._redis: - return await self._redis.close() + await self._redis.close() async def wait_closed(self): - if self._redis: - await self._redis.wait_closed() - self._redis = None + pass async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_KEY) - redis = await self._get_adapter() - return await redis.get(key) or self.resolve_state(default) + return await self._redis.get(key) or self.resolve_state(default) async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[dict] = None) -> typing.Dict: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_DATA_KEY) - redis = await self._get_adapter() - raw_result = await redis.get(key) + raw_result = await self._redis.get(key) if raw_result: return json.loads(raw_result) return default or {} @@ -452,21 +295,19 @@ async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: ty state: typing.Optional[typing.AnyStr] = None): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_KEY) - redis = await self._get_adapter() if state is None: - await redis.delete(key) + await self._redis.delete(key) else: - await redis.set(key, self.resolve_state(state), ex=self._state_ttl) + await self._redis.set(key, self.resolve_state(state), ex=self._state_ttl) async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_DATA_KEY) - redis = await self._get_adapter() if data: - await redis.set(key, json.dumps(data), ex=self._data_ttl) + await self._redis.set(key, json.dumps(data), ex=self._data_ttl) else: - await redis.delete(key) + await self._redis.delete(key) async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None, **kwargs): @@ -483,8 +324,7 @@ async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: t default: typing.Optional[dict] = None) -> typing.Dict: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_BUCKET_KEY) - redis = await self._get_adapter() - raw_result = await redis.get(key) + raw_result = await self._redis.get(key) if raw_result: return json.loads(raw_result) return default or {} @@ -493,11 +333,10 @@ async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: t bucket: typing.Dict = None): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_BUCKET_KEY) - redis = await self._get_adapter() if bucket: - await redis.set(key, json.dumps(bucket), ex=self._bucket_ttl) + await self._redis.set(key, json.dumps(bucket), ex=self._bucket_ttl) else: - await redis.delete(key) + await self._redis.delete(key) async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, @@ -515,13 +354,11 @@ async def reset_all(self, full=True): :param full: clean DB or clean only states :return: """ - redis = await self._get_adapter() - if full: - await redis.flushdb() + await self._redis.flushdb() else: - keys = await redis.keys(self.generate_key('*')) - await redis.delete(*keys) + keys = await self._redis.keys(self.generate_key('*')) + await self._redis.delete(*keys) async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]: """ @@ -529,10 +366,9 @@ async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]: :return: list of tuples where first element is chat id and second is user id """ - redis = await self._get_adapter() result = [] - keys = await redis.keys(self.generate_key('*', '*', STATE_KEY)) + keys = await self._redis.keys(self.generate_key('*', '*', STATE_KEY)) for item in keys: *_, chat, user, _ = item.split(':') result.append((chat, user)) diff --git a/tests/contrib/fsm_storage/test_storage.py b/tests/contrib/fsm_storage/test_storage.py deleted file mode 100644 index ae06025c6d..0000000000 --- a/tests/contrib/fsm_storage/test_storage.py +++ /dev/null @@ -1,89 +0,0 @@ -import aioredis -import pytest -import pytest_asyncio -from pytest_lazyfixture import lazy_fixture - -from aiogram.contrib.fsm_storage.memory import MemoryStorage -from aiogram.contrib.fsm_storage.redis import RedisStorage, RedisStorage2 - - -@pytest_asyncio.fixture() -@pytest.mark.redis -async def redis_store(redis_options): - if int(aioredis.__version__.split(".")[0]) == 2: - pytest.skip('aioredis v2 is not supported.') - return - s = RedisStorage(**redis_options) - try: - yield s - finally: - conn = await s.redis() - await conn.execute('FLUSHDB') - await s.close() - await s.wait_closed() - - -@pytest_asyncio.fixture() -@pytest.mark.redis -async def redis_store2(redis_options): - s = RedisStorage2(**redis_options) - try: - yield s - finally: - conn = await s.redis() - await conn.flushdb() - await s.close() - await s.wait_closed() - - -@pytest_asyncio.fixture() -async def memory_store(): - yield MemoryStorage() - - -@pytest.mark.parametrize( - "store", [ - lazy_fixture('redis_store'), - lazy_fixture('redis_store2'), - lazy_fixture('memory_store'), - ] -) -class TestStorage: - @pytest.mark.asyncio - async def test_set_get(self, store): - assert await store.get_data(chat='1234') == {} - await store.set_data(chat='1234', data={'foo': 'bar'}) - assert await store.get_data(chat='1234') == {'foo': 'bar'} - - @pytest.mark.asyncio - async def test_reset(self, store): - await store.set_data(chat='1234', data={'foo': 'bar'}) - await store.reset_data(chat='1234') - assert await store.get_data(chat='1234') == {} - - @pytest.mark.asyncio - async def test_reset_empty(self, store): - await store.reset_data(chat='1234') - assert await store.get_data(chat='1234') == {} - - -@pytest.mark.parametrize( - "store", [ - lazy_fixture('redis_store'), - lazy_fixture('redis_store2'), - ] -) -class TestRedisStorage2: - @pytest.mark.asyncio - async def test_close_and_open_connection(self, store): - await store.set_data(chat='1234', data={'foo': 'bar'}) - assert await store.get_data(chat='1234') == {'foo': 'bar'} - pool_id = id(store._redis) - await store.close() - await store.wait_closed() - - # new pool will be open at this point - assert await store.get_data(chat='1234') == { - 'foo': 'bar', - } - assert id(store._redis) != pool_id diff --git a/tests/test_contrib/test_fsm_storage/test_storage.py b/tests/test_contrib/test_fsm_storage/test_storage.py new file mode 100644 index 0000000000..93295c9193 --- /dev/null +++ b/tests/test_contrib/test_fsm_storage/test_storage.py @@ -0,0 +1,160 @@ +import aioredis +import pytest +import pytest_asyncio +from pytest_lazyfixture import lazy_fixture +from redis.asyncio.connection import Connection, ConnectionPool + +from aiogram.contrib.fsm_storage.memory import MemoryStorage +from aiogram.contrib.fsm_storage.redis import RedisStorage, RedisStorage2 +from aiogram.types import Chat, User +from tests.types.dataset import CHAT, USER + +pytestmark = pytest.mark.asyncio + +@pytest_asyncio.fixture() +@pytest.mark.redis +async def redis_store(redis_options): + if int(aioredis.__version__.split(".")[0]) == 2: + pytest.skip('aioredis v2 is not supported.') + return + s = RedisStorage(**redis_options) + try: + yield s + finally: + conn = await s.redis() + await conn.execute('FLUSHDB') + await s.close() + await s.wait_closed() + + +@pytest_asyncio.fixture() +@pytest.mark.redis +async def redis_store2(redis_options): + s = RedisStorage2(**redis_options) + try: + yield s + finally: + conn = await s.redis() + await conn.flushdb() + await s.close() + await s.wait_closed() + + +@pytest_asyncio.fixture() +async def memory_store(): + yield MemoryStorage() + + +@pytest.mark.parametrize( + "store", [ + lazy_fixture('redis_store'), + lazy_fixture('redis_store2'), + lazy_fixture('memory_store'), + ] +) +class TestStorage: + async def test_set_get(self, store): + assert await store.get_data(chat='1234') == {} + await store.set_data(chat='1234', data={'foo': 'bar'}) + assert await store.get_data(chat='1234') == {'foo': 'bar'} + + async def test_reset(self, store): + await store.set_data(chat='1234', data={'foo': 'bar'}) + await store.reset_data(chat='1234') + assert await store.get_data(chat='1234') == {} + + async def test_reset_empty(self, store): + await store.reset_data(chat='1234') + assert await store.get_data(chat='1234') == {} + + +@pytest.mark.parametrize( + "store", [ + lazy_fixture('redis_store'), + lazy_fixture('redis_store2'), + ] +) +class TestRedisStorage2: + async def test_close_and_open_connection(self, store: RedisStorage2): + await store.set_data(chat='1234', data={'foo': 'bar'}) + assert await store.get_data(chat='1234') == {'foo': 'bar'} + await store.close() + await store.wait_closed() + + pool: ConnectionPool = store._redis.connection_pool + + # noinspection PyUnresolvedReferences + assert not pool._in_use_connections + + # noinspection PyUnresolvedReferences + if pool._available_connections: + # noinspection PyUnresolvedReferences + connection: Connection = pool._available_connections[0] + assert connection.is_connected is False + + @pytest.mark.parametrize( + "chat_id,user_id,state", + [ + [12345, 54321, "foo"], + [12345, 54321, None], + [12345, None, "foo"], + [None, 54321, "foo"], + ], + ) + async def test_set_get_state(self, chat_id, user_id, state, store): + await store.reset_state(chat=chat_id, user=user_id, with_data=False) + + await store.set_state(chat=chat_id, user=user_id, state=state) + s = await store.get_state(chat=chat_id, user=user_id) + assert s == state + + @pytest.mark.parametrize( + "chat_id,user_id,data,new_data", + [ + [12345, 54321, {"foo": "bar"}, {"bar": "foo"}], + [12345, 54321, None, None], + [12345, 54321, {"foo": "bar"}, None], + [12345, 54321, None, {"bar": "foo"}], + [12345, None, {"foo": "bar"}, {"bar": "foo"}], + [None, 54321, {"foo": "bar"}, {"bar": "foo"}], + ], + ) + async def test_set_get_update_data(self, chat_id, user_id, data, new_data, store): + await store.reset_state(chat=chat_id, user=user_id, with_data=True) + + await store.set_data(chat=chat_id, user=user_id, data=data) + d = await store.get_data(chat=chat_id, user=user_id) + assert d == (data or {}) + + await store.update_data(chat=chat_id, user=user_id, data=new_data) + d = await store.get_data(chat=chat_id, user=user_id) + updated_data = (data or {}) + updated_data.update(new_data or {}) + assert d == updated_data + + async def test_has_bucket(self, store): + assert store.has_bucket() + + @pytest.mark.parametrize( + "chat_id,user_id,data,new_data", + [ + [12345, 54321, {"foo": "bar"}, {"bar": "foo"}], + [12345, 54321, None, None], + [12345, 54321, {"foo": "bar"}, None], + [12345, 54321, None, {"bar": "foo"}], + [12345, None, {"foo": "bar"}, {"bar": "foo"}], + [None, 54321, {"foo": "bar"}, {"bar": "foo"}], + ], + ) + async def test_set_get_update_bucket(self, chat_id, user_id, data, new_data, store): + await store.reset_state(chat=chat_id, user=user_id, with_data=True) + + await store.set_bucket(chat=chat_id, user=user_id, bucket=data) + d = await store.get_bucket(chat=chat_id, user=user_id) + assert d == (data or {}) + + await store.update_bucket(chat=chat_id, user=user_id, bucket=new_data) + d = await store.get_bucket(chat=chat_id, user=user_id) + updated_bucket = (data or {}) + updated_bucket.update(new_data or {}) + assert d == updated_bucket