Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/copilot-instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Always reference these instructions first and fallback to search or bash command
- `pip install -e .` -- installs the package in development mode. NEVER CANCEL: Takes 10-60 seconds, may timeout due to network issues. Set timeout to 120+ seconds (extra margin for slow mirrors or network issues).
- `pip install coverage pre-commit pytest pytest-cov pytest-dotenv ruff` -- installs development dependencies. NEVER CANCEL: Takes 30-120 seconds. Set timeout to 180+ seconds (extra margin for slow mirrors or network issues).
- `python -m pytest tests/ -v` -- runs unit tests (takes ~0.4 seconds, 20 passed, 5 skipped without database)
- `ruff check .` -- runs linting (takes ~0.01 seconds)
- `ruff check .` -- runs linting (takes ~0.01 seconds)
- `ruff format --check .` -- checks code formatting (takes ~0.01 seconds)

### Environment Configuration:
Expand Down Expand Up @@ -152,4 +152,4 @@ mem-db-utils/
- Check if database container is running: `docker ps`
- Test connection manually: `docker exec -it test-redis redis-cli ping`
- Verify port availability: `netstat -tlnp | grep 6379`
- Check firewall settings if running on remote host
- Check firewall settings if running on remote host
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dev = [
"coverage>=7.10.2",
"pre-commit>=4.2.0",
"pytest>=8.4.1",
"pytest-asyncio>=1.2.0",
"pytest-cov>=6.2.1",
"pytest-dotenv>=0.5.2",
]
Expand Down
58 changes: 58 additions & 0 deletions src/mem_db_utils/asyncio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from urllib.parse import urlparse

import redis.asyncio as aioredis

from mem_db_utils.config import DBConfig, DBType


class MemDBConnector:
__slots__ = ("uri", "db_type", "connection_type", "service")

def __init__(self, redis_type: str = None, master_service: str = None):
self.uri = DBConfig.db_url
self.db_type = DBConfig.db_type
self.service = None
self.connection_type = None
if self.db_type == DBType.REDIS:
self.connection_type = redis_type or DBConfig.redis_connection_type
self.service = master_service or DBConfig.redis_master_service

async def connect(self, db: int = 0, **kwargs):
"""
The async connect function is used to connect to a MemDB instance asynchronously.

:param self: Represent the instance of the class
:param db: int: Specify the database number to connect to
:return: An async connection object
"""
if self.connection_type == "sentinel":
return await self._sentinel(db=db, **kwargs)
return await aioredis.from_url(url=self.uri, db=db, decode_responses=kwargs.get("decode_response", True))

async def _sentinel(self, db: int, **kwargs):
"""
The async _sentinel function is used to connect to a Redis Sentinel service asynchronously.

:param self: Bind the method to an instance of the class
:param db: int: Select the database to connect to
:return: An async connection object
"""
parsed_uri = urlparse(self.uri)
sentinel_host = parsed_uri.hostname
sentinel_port = parsed_uri.port
redis_password = parsed_uri.password
sentinel_hosts = [(sentinel_host, sentinel_port)]

sentinel = aioredis.Sentinel(
sentinel_hosts,
socket_timeout=kwargs.get("timeout", DBConfig.db_timeout),
password=redis_password,
)

# Connect to the Redis Sentinel master service and select the specified database
connection_object = sentinel.master_for(self.service, decode_responses=kwargs.get("decode_response", True))
await connection_object.select(db)
return connection_object


__all__ = ["MemDBConnector"]
182 changes: 182 additions & 0 deletions tests/test_async_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""Tests for async MemDBConnector class."""

from unittest.mock import AsyncMock, patch
from urllib.parse import urlparse

import pytest

from mem_db_utils.asyncio import MemDBConnector as AsyncMemDBConnector
from mem_db_utils.config import DBConfig, DBType


class TestAsyncMemDBConnector:
"""Test the AsyncMemDBConnector class."""

def test_init_with_defaults(self):
"""Test initialization with default values from .env file."""
connector = AsyncMemDBConnector()
assert connector.uri == DBConfig.db_url
assert connector.db_type == DBConfig.db_type

# For Redis databases
if connector.db_type == DBType.REDIS:
assert connector.connection_type == DBConfig.redis_connection_type
assert connector.service == DBConfig.redis_master_service
else:
# For non-Redis databases, these should be None
assert connector.connection_type is None
assert connector.service is None

def test_init_with_redis_type_override(self):
"""Test initialization with Redis connection type override."""
connector = AsyncMemDBConnector(redis_type="sentinel")
assert connector.uri == DBConfig.db_url
assert connector.db_type == DBConfig.db_type

if connector.db_type == DBType.REDIS:
assert connector.connection_type == "sentinel"
assert connector.service == DBConfig.redis_master_service

def test_init_with_master_service_override(self):
"""Test initialization with master service override."""
connector = AsyncMemDBConnector(master_service="custom_master")
assert connector.uri == DBConfig.db_url
assert connector.db_type == DBConfig.db_type

if connector.db_type == DBType.REDIS:
assert connector.connection_type == DBConfig.redis_connection_type
assert connector.service == "custom_master"

@pytest.mark.asyncio
@patch("redis.asyncio.from_url")
async def test_connect_direct_connection(self, mock_from_url):
"""Test direct async database connection."""
mock_connection = AsyncMock()

# Mock from_url to return a coroutine
async def mock_coro():
return mock_connection

mock_from_url.return_value = mock_coro()

connector = AsyncMemDBConnector()
# Only test direct connection if not using sentinel
if connector.connection_type != "sentinel":
result = await connector.connect(db=1)

mock_from_url.assert_called_once_with(url=DBConfig.db_url, db=1, decode_responses=True)
assert result == mock_connection

@pytest.mark.asyncio
@patch("redis.asyncio.from_url")
async def test_connect_with_custom_kwargs(self, mock_from_url):
"""Test async connection with custom keyword arguments."""
mock_connection = AsyncMock()

# Mock from_url to return a coroutine
async def mock_coro():
return mock_connection

mock_from_url.return_value = mock_coro()

connector = AsyncMemDBConnector()
if connector.connection_type != "sentinel":
result = await connector.connect(db=2, decode_response=False)

mock_from_url.assert_called_once_with(url=DBConfig.db_url, db=2, decode_responses=False)
assert result == mock_connection

@pytest.mark.asyncio
@patch("redis.asyncio.Sentinel")
async def test_connect_sentinel(self, mock_sentinel_class):
"""Test async Redis Sentinel connection when configured."""
mock_sentinel = AsyncMock()
mock_master = AsyncMock()
mock_master.select = AsyncMock()
mock_sentinel.master_for.return_value = mock_master
mock_sentinel_class.return_value = mock_sentinel

connector = AsyncMemDBConnector()
if connector.connection_type == "sentinel" and connector.db_type == DBType.REDIS:
result = await connector.connect(db=3)

# Verify Sentinel was created with correct parameters
parsed_uri = urlparse(DBConfig.db_url)
expected_hosts = [(parsed_uri.hostname, parsed_uri.port)]

mock_sentinel_class.assert_called_once_with(
expected_hosts, socket_timeout=DBConfig.db_timeout, password=parsed_uri.password
)

# Verify master connection was requested
mock_sentinel.master_for.assert_called_once_with(DBConfig.redis_master_service, decode_responses=True)

# Verify database selection
mock_master.select.assert_called_once_with(3)
assert result == mock_master

@pytest.mark.asyncio
@patch("redis.asyncio.from_url")
async def test_connect_default_db(self, mock_from_url):
"""Test async connection with default database (0)."""
mock_connection = AsyncMock()

# Mock from_url to return a coroutine
async def mock_coro():
return mock_connection

mock_from_url.return_value = mock_coro()

connector = AsyncMemDBConnector()
if connector.connection_type != "sentinel":
result = await connector.connect() # No db parameter

mock_from_url.assert_called_once_with(
url=DBConfig.db_url,
db=0, # Default value
decode_responses=True,
)
assert result == mock_connection

def test_slots_attribute(self):
"""Test that the class uses __slots__ for memory efficiency."""
connector = AsyncMemDBConnector()

# Check that __slots__ is defined
assert hasattr(AsyncMemDBConnector, "__slots__")
expected_slots = ("uri", "db_type", "connection_type", "service")
assert AsyncMemDBConnector.__slots__ == expected_slots

# Verify we can't add arbitrary attributes
with pytest.raises(AttributeError):
connector.new_attribute = "test"

def test_non_redis_db_type_behavior(self):
"""Test async connector behavior with non-Redis database types."""
connector = AsyncMemDBConnector()
if connector.db_type != DBType.REDIS:
# For non-Redis databases, connection_type and service should be None
assert connector.connection_type is None
assert connector.service is None

@pytest.mark.asyncio
async def test_error_handling_in_connect(self):
"""Test error handling in async connect method."""
with patch("redis.asyncio.from_url") as mock_from_url:
mock_from_url.side_effect = Exception("Connection failed")

connector = AsyncMemDBConnector()
if connector.connection_type != "sentinel":
with pytest.raises(Exception, match="Connection failed"):
await connector.connect()

@pytest.mark.asyncio
async def test_error_handling_in_sentinel(self):
"""Test error handling in async sentinel method."""
with patch("redis.asyncio.Sentinel") as mock_sentinel_class:
mock_sentinel_class.side_effect = Exception("Sentinel connection failed")

connector = AsyncMemDBConnector(redis_type="sentinel")
if connector.db_type == DBType.REDIS:
with pytest.raises(Exception, match="Sentinel connection failed"):
await connector.connect()
Loading