Skip to content

Commit

Permalink
Migrate from aioredis to redis.asyncio (#1074)
Browse files Browse the repository at this point in the history
* chore: migrate from aioredis to redis.asyncio

* chore: add tests for RedisStorage2
  • Loading branch information
Olegt0rr committed Dec 4, 2022
1 parent 87c0458 commit ae53429
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 288 deletions.
234 changes: 35 additions & 199 deletions aiogram/contrib/fsm_storage/redis.py
@@ -1,18 +1,18 @@
"""
This module has redis storage for finite-state machine based on `aioredis <https://github.com/aio-libs/aioredis>`_ driver
This module has redis storage for finite-state machine based on `redis <https://pypi.org/project/redis/>`_ driver.
"""

import asyncio
import logging
import typing
from abc import ABC, abstractmethod

import aioredis

from ...dispatcher.storage import BaseStorage
from ...utils import json
from ...utils.deprecated import deprecated

if typing.TYPE_CHECKING:
import aioredis

STATE_KEY = 'state'
STATE_DATA_KEY = 'data'
STATE_BUCKET_KEY = 'bucket'
Expand Down Expand Up @@ -67,6 +67,8 @@ async def redis(self) -> "aioredis.RedisConnection":
Get Redis connection
"""
# Use thread-safe asyncio Lock because this method without that is not safe
import aioredis

async with self._connection_lock:
if self._redis is None or self._redis.closed:
self._redis = await aioredis.create_connection((self._host, self._port),
Expand Down Expand Up @@ -207,138 +209,6 @@ async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket)


class AioRedisAdapterBase(ABC):
"""Base aioredis adapter class."""

def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: typing.Optional[int] = None,
password: typing.Optional[str] = None,
ssl: typing.Optional[bool] = None,
pool_size: int = 10,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
prefix: str = "fsm",
state_ttl: typing.Optional[int] = None,
data_ttl: typing.Optional[int] = None,
bucket_ttl: typing.Optional[int] = None,
**kwargs,
):
self._host = host
self._port = port
self._db = db
self._password = password
self._ssl = ssl
self._pool_size = pool_size
self._kwargs = kwargs
self._prefix = (prefix,)

self._state_ttl = state_ttl
self._data_ttl = data_ttl
self._bucket_ttl = bucket_ttl

self._redis: typing.Optional["aioredis.Redis"] = None
self._connection_lock = asyncio.Lock()

@abstractmethod
async def get_redis(self) -> aioredis.Redis:
"""Get Redis connection."""
pass

async def close(self):
"""Grace shutdown."""
pass

async def wait_closed(self):
"""Wait for grace shutdown finishes."""
pass

async def set(self, name, value, ex=None, **kwargs):
"""Set the value at key ``name`` to ``value``."""
if ex == 0:
ex = None
return await self._redis.set(name, value, ex=ex, **kwargs)

async def get(self, name, **kwargs):
"""Return the value at key ``name`` or None."""
return await self._redis.get(name, **kwargs)

async def delete(self, *names):
"""Delete one or more keys specified by ``names``"""
return await self._redis.delete(*names)

async def keys(self, pattern, **kwargs):
"""Returns a list of keys matching ``pattern``."""
return await self._redis.keys(pattern, **kwargs)

async def flushdb(self):
"""Delete all keys in the current database."""
return await self._redis.flushdb()


class AioRedisAdapterV1(AioRedisAdapterBase):
"""Redis adapter for aioredis v1."""

async def get_redis(self) -> aioredis.Redis:
"""Get Redis connection."""
async with self._connection_lock: # to prevent race
if self._redis is None or self._redis.closed:
self._redis = await aioredis.create_redis_pool(
(self._host, self._port),
db=self._db,
password=self._password,
ssl=self._ssl,
minsize=1,
maxsize=self._pool_size,
**self._kwargs,
)
return self._redis

async def close(self):
async with self._connection_lock:
if self._redis and not self._redis.closed:
self._redis.close()

async def wait_closed(self):
async with self._connection_lock:
if self._redis:
return await self._redis.wait_closed()
return True

async def get(self, name, **kwargs):
return await self._redis.get(name, encoding="utf8", **kwargs)

async def set(self, name, value, ex=None, **kwargs):
if ex == 0:
ex = None
return await self._redis.set(name, value, expire=ex, **kwargs)

async def keys(self, pattern, **kwargs):
"""Returns a list of keys matching ``pattern``."""
return await self._redis.keys(pattern, encoding="utf8", **kwargs)


class AioRedisAdapterV2(AioRedisAdapterBase):
"""Redis adapter for aioredis v2."""

async def get_redis(self) -> aioredis.Redis:
"""Get Redis connection."""
async with self._connection_lock: # to prevent race
if self._redis is None:
self._redis = aioredis.Redis(
host=self._host,
port=self._port,
db=self._db,
password=self._password,
ssl=self._ssl,
max_connections=self._pool_size,
decode_responses=True,
**self._kwargs,
)
return self._redis


class RedisStorage2(BaseStorage):
"""
Busted Redis-base storage for FSM.
Expand All @@ -356,7 +226,6 @@ class RedisStorage2(BaseStorage):
.. code-block:: python3
await dp.storage.close()
await dp.storage.wait_closed()
"""

Expand All @@ -375,75 +244,49 @@ def __init__(
bucket_ttl: typing.Optional[int] = None,
**kwargs,
):
self._host = host
self._port = port
self._db = db
self._password = password
self._ssl = ssl
self._pool_size = pool_size
self._kwargs = kwargs
self._prefix = (prefix,)
from redis.asyncio import Redis

self._redis: typing.Optional[Redis] = Redis(
host=host,
port=port,
db=db,
password=password,
ssl=ssl,
max_connections=pool_size,
decode_responses=True,
**kwargs,
)

self._prefix = (prefix,)
self._state_ttl = state_ttl
self._data_ttl = data_ttl
self._bucket_ttl = bucket_ttl

self._redis: typing.Optional[AioRedisAdapterBase] = None
self._connection_lock = asyncio.Lock()

@deprecated("This method will be removed in aiogram v3.0. "
"You should use your own instance of Redis.", stacklevel=3)
async def redis(self) -> aioredis.Redis:
adapter = await self._get_adapter()
return await adapter.get_redis()

async def _get_adapter(self) -> AioRedisAdapterBase:
"""Get adapter based on aioredis version."""
if self._redis is None:
redis_version = int(aioredis.__version__.split(".")[0])
connection_data = dict(
host=self._host,
port=self._port,
db=self._db,
password=self._password,
ssl=self._ssl,
pool_size=self._pool_size,
**self._kwargs,
)
if redis_version == 1:
self._redis = AioRedisAdapterV1(**connection_data)
elif redis_version == 2:
self._redis = AioRedisAdapterV2(**connection_data)
else:
raise RuntimeError(f"Unsupported aioredis version: {redis_version}")
await self._redis.get_redis()
async def redis(self) -> "aioredis.Redis":
return self._redis

def generate_key(self, *parts):
return ':'.join(self._prefix + tuple(map(str, parts)))

async def close(self):
if self._redis:
return await self._redis.close()
await self._redis.close()

async def wait_closed(self):
if self._redis:
await self._redis.wait_closed()
self._redis = None
pass

async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_KEY)
redis = await self._get_adapter()
return await redis.get(key) or self.resolve_state(default)
return await self._redis.get(key) or self.resolve_state(default)

async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_DATA_KEY)
redis = await self._get_adapter()
raw_result = await redis.get(key)
raw_result = await self._redis.get(key)
if raw_result:
return json.loads(raw_result)
return default or {}
Expand All @@ -452,21 +295,19 @@ async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: ty
state: typing.Optional[typing.AnyStr] = None):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_KEY)
redis = await self._get_adapter()
if state is None:
await redis.delete(key)
await self._redis.delete(key)
else:
await redis.set(key, self.resolve_state(state), ex=self._state_ttl)
await self._redis.set(key, self.resolve_state(state), ex=self._state_ttl)

async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_DATA_KEY)
redis = await self._get_adapter()
if data:
await redis.set(key, json.dumps(data), ex=self._data_ttl)
await self._redis.set(key, json.dumps(data), ex=self._data_ttl)
else:
await redis.delete(key)
await self._redis.delete(key)

async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None, **kwargs):
Expand All @@ -483,8 +324,7 @@ async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: t
default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
redis = await self._get_adapter()
raw_result = await redis.get(key)
raw_result = await self._redis.get(key)
if raw_result:
return json.loads(raw_result)
return default or {}
Expand All @@ -493,11 +333,10 @@ async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: t
bucket: typing.Dict = None):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
redis = await self._get_adapter()
if bucket:
await redis.set(key, json.dumps(bucket), ex=self._bucket_ttl)
await self._redis.set(key, json.dumps(bucket), ex=self._bucket_ttl)
else:
await redis.delete(key)
await self._redis.delete(key)

async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
Expand All @@ -515,24 +354,21 @@ async def reset_all(self, full=True):
:param full: clean DB or clean only states
:return:
"""
redis = await self._get_adapter()

if full:
await redis.flushdb()
await self._redis.flushdb()
else:
keys = await redis.keys(self.generate_key('*'))
await redis.delete(*keys)
keys = await self._redis.keys(self.generate_key('*'))
await self._redis.delete(*keys)

async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]:
"""
Get list of all stored chat's and user's
:return: list of tuples where first element is chat id and second is user id
"""
redis = await self._get_adapter()
result = []

keys = await redis.keys(self.generate_key('*', '*', STATE_KEY))
keys = await self._redis.keys(self.generate_key('*', '*', STATE_KEY))
for item in keys:
*_, chat, user, _ = item.split(':')
result.append((chat, user))
Expand Down

0 comments on commit ae53429

Please sign in to comment.