diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 72d43e8..5a885a0 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -10,7 +10,7 @@ Always reference these instructions first and fallback to search or bash command - `pip install -e .` -- installs the package in development mode. NEVER CANCEL: Takes 10-60 seconds, may timeout due to network issues. Set timeout to 120+ seconds (extra margin for slow mirrors or network issues). - `pip install coverage pre-commit pytest pytest-cov pytest-dotenv ruff` -- installs development dependencies. NEVER CANCEL: Takes 30-120 seconds. Set timeout to 180+ seconds (extra margin for slow mirrors or network issues). - `python -m pytest tests/ -v` -- runs unit tests (takes ~0.4 seconds, 20 passed, 5 skipped without database) -- `ruff check .` -- runs linting (takes ~0.01 seconds) +- `ruff check .` -- runs linting (takes ~0.01 seconds) - `ruff format --check .` -- checks code formatting (takes ~0.01 seconds) ### Environment Configuration: @@ -152,4 +152,4 @@ mem-db-utils/ - Check if database container is running: `docker ps` - Test connection manually: `docker exec -it test-redis redis-cli ping` - Verify port availability: `netstat -tlnp | grep 6379` -- Check firewall settings if running on remote host \ No newline at end of file +- Check firewall settings if running on remote host diff --git a/pyproject.toml b/pyproject.toml index 719966a..15cef48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dev = [ "coverage>=7.10.2", "pre-commit>=4.2.0", "pytest>=8.4.1", + "pytest-asyncio>=1.2.0", "pytest-cov>=6.2.1", "pytest-dotenv>=0.5.2", ] diff --git a/src/mem_db_utils/asyncio/__init__.py b/src/mem_db_utils/asyncio/__init__.py new file mode 100644 index 0000000..792c15e --- /dev/null +++ b/src/mem_db_utils/asyncio/__init__.py @@ -0,0 +1,58 @@ +from urllib.parse import urlparse + +import redis.asyncio as aioredis + +from mem_db_utils.config import DBConfig, DBType + + +class MemDBConnector: + __slots__ = ("uri", "db_type", "connection_type", "service") + + def __init__(self, redis_type: str = None, master_service: str = None): + self.uri = DBConfig.db_url + self.db_type = DBConfig.db_type + self.service = None + self.connection_type = None + if self.db_type == DBType.REDIS: + self.connection_type = redis_type or DBConfig.redis_connection_type + self.service = master_service or DBConfig.redis_master_service + + async def connect(self, db: int = 0, **kwargs): + """ + The async connect function is used to connect to a MemDB instance asynchronously. + + :param self: Represent the instance of the class + :param db: int: Specify the database number to connect to + :return: An async connection object + """ + if self.connection_type == "sentinel": + return await self._sentinel(db=db, **kwargs) + return await aioredis.from_url(url=self.uri, db=db, decode_responses=kwargs.get("decode_response", True)) + + async def _sentinel(self, db: int, **kwargs): + """ + The async _sentinel function is used to connect to a Redis Sentinel service asynchronously. + + :param self: Bind the method to an instance of the class + :param db: int: Select the database to connect to + :return: An async connection object + """ + parsed_uri = urlparse(self.uri) + sentinel_host = parsed_uri.hostname + sentinel_port = parsed_uri.port + redis_password = parsed_uri.password + sentinel_hosts = [(sentinel_host, sentinel_port)] + + sentinel = aioredis.Sentinel( + sentinel_hosts, + socket_timeout=kwargs.get("timeout", DBConfig.db_timeout), + password=redis_password, + ) + + # Connect to the Redis Sentinel master service and select the specified database + connection_object = sentinel.master_for(self.service, decode_responses=kwargs.get("decode_response", True)) + await connection_object.select(db) + return connection_object + + +__all__ = ["MemDBConnector"] diff --git a/tests/test_async_connector.py b/tests/test_async_connector.py new file mode 100644 index 0000000..1679918 --- /dev/null +++ b/tests/test_async_connector.py @@ -0,0 +1,182 @@ +"""Tests for async MemDBConnector class.""" + +from unittest.mock import AsyncMock, patch +from urllib.parse import urlparse + +import pytest + +from mem_db_utils.asyncio import MemDBConnector as AsyncMemDBConnector +from mem_db_utils.config import DBConfig, DBType + + +class TestAsyncMemDBConnector: + """Test the AsyncMemDBConnector class.""" + + def test_init_with_defaults(self): + """Test initialization with default values from .env file.""" + connector = AsyncMemDBConnector() + assert connector.uri == DBConfig.db_url + assert connector.db_type == DBConfig.db_type + + # For Redis databases + if connector.db_type == DBType.REDIS: + assert connector.connection_type == DBConfig.redis_connection_type + assert connector.service == DBConfig.redis_master_service + else: + # For non-Redis databases, these should be None + assert connector.connection_type is None + assert connector.service is None + + def test_init_with_redis_type_override(self): + """Test initialization with Redis connection type override.""" + connector = AsyncMemDBConnector(redis_type="sentinel") + assert connector.uri == DBConfig.db_url + assert connector.db_type == DBConfig.db_type + + if connector.db_type == DBType.REDIS: + assert connector.connection_type == "sentinel" + assert connector.service == DBConfig.redis_master_service + + def test_init_with_master_service_override(self): + """Test initialization with master service override.""" + connector = AsyncMemDBConnector(master_service="custom_master") + assert connector.uri == DBConfig.db_url + assert connector.db_type == DBConfig.db_type + + if connector.db_type == DBType.REDIS: + assert connector.connection_type == DBConfig.redis_connection_type + assert connector.service == "custom_master" + + @pytest.mark.asyncio + @patch("redis.asyncio.from_url") + async def test_connect_direct_connection(self, mock_from_url): + """Test direct async database connection.""" + mock_connection = AsyncMock() + + # Mock from_url to return a coroutine + async def mock_coro(): + return mock_connection + + mock_from_url.return_value = mock_coro() + + connector = AsyncMemDBConnector() + # Only test direct connection if not using sentinel + if connector.connection_type != "sentinel": + result = await connector.connect(db=1) + + mock_from_url.assert_called_once_with(url=DBConfig.db_url, db=1, decode_responses=True) + assert result == mock_connection + + @pytest.mark.asyncio + @patch("redis.asyncio.from_url") + async def test_connect_with_custom_kwargs(self, mock_from_url): + """Test async connection with custom keyword arguments.""" + mock_connection = AsyncMock() + + # Mock from_url to return a coroutine + async def mock_coro(): + return mock_connection + + mock_from_url.return_value = mock_coro() + + connector = AsyncMemDBConnector() + if connector.connection_type != "sentinel": + result = await connector.connect(db=2, decode_response=False) + + mock_from_url.assert_called_once_with(url=DBConfig.db_url, db=2, decode_responses=False) + assert result == mock_connection + + @pytest.mark.asyncio + @patch("redis.asyncio.Sentinel") + async def test_connect_sentinel(self, mock_sentinel_class): + """Test async Redis Sentinel connection when configured.""" + mock_sentinel = AsyncMock() + mock_master = AsyncMock() + mock_master.select = AsyncMock() + mock_sentinel.master_for.return_value = mock_master + mock_sentinel_class.return_value = mock_sentinel + + connector = AsyncMemDBConnector() + if connector.connection_type == "sentinel" and connector.db_type == DBType.REDIS: + result = await connector.connect(db=3) + + # Verify Sentinel was created with correct parameters + parsed_uri = urlparse(DBConfig.db_url) + expected_hosts = [(parsed_uri.hostname, parsed_uri.port)] + + mock_sentinel_class.assert_called_once_with( + expected_hosts, socket_timeout=DBConfig.db_timeout, password=parsed_uri.password + ) + + # Verify master connection was requested + mock_sentinel.master_for.assert_called_once_with(DBConfig.redis_master_service, decode_responses=True) + + # Verify database selection + mock_master.select.assert_called_once_with(3) + assert result == mock_master + + @pytest.mark.asyncio + @patch("redis.asyncio.from_url") + async def test_connect_default_db(self, mock_from_url): + """Test async connection with default database (0).""" + mock_connection = AsyncMock() + + # Mock from_url to return a coroutine + async def mock_coro(): + return mock_connection + + mock_from_url.return_value = mock_coro() + + connector = AsyncMemDBConnector() + if connector.connection_type != "sentinel": + result = await connector.connect() # No db parameter + + mock_from_url.assert_called_once_with( + url=DBConfig.db_url, + db=0, # Default value + decode_responses=True, + ) + assert result == mock_connection + + def test_slots_attribute(self): + """Test that the class uses __slots__ for memory efficiency.""" + connector = AsyncMemDBConnector() + + # Check that __slots__ is defined + assert hasattr(AsyncMemDBConnector, "__slots__") + expected_slots = ("uri", "db_type", "connection_type", "service") + assert AsyncMemDBConnector.__slots__ == expected_slots + + # Verify we can't add arbitrary attributes + with pytest.raises(AttributeError): + connector.new_attribute = "test" + + def test_non_redis_db_type_behavior(self): + """Test async connector behavior with non-Redis database types.""" + connector = AsyncMemDBConnector() + if connector.db_type != DBType.REDIS: + # For non-Redis databases, connection_type and service should be None + assert connector.connection_type is None + assert connector.service is None + + @pytest.mark.asyncio + async def test_error_handling_in_connect(self): + """Test error handling in async connect method.""" + with patch("redis.asyncio.from_url") as mock_from_url: + mock_from_url.side_effect = Exception("Connection failed") + + connector = AsyncMemDBConnector() + if connector.connection_type != "sentinel": + with pytest.raises(Exception, match="Connection failed"): + await connector.connect() + + @pytest.mark.asyncio + async def test_error_handling_in_sentinel(self): + """Test error handling in async sentinel method.""" + with patch("redis.asyncio.Sentinel") as mock_sentinel_class: + mock_sentinel_class.side_effect = Exception("Sentinel connection failed") + + connector = AsyncMemDBConnector(redis_type="sentinel") + if connector.db_type == DBType.REDIS: + with pytest.raises(Exception, match="Sentinel connection failed"): + await connector.connect() diff --git a/tests/test_async_integration.py b/tests/test_async_integration.py new file mode 100644 index 0000000..1e3ee96 --- /dev/null +++ b/tests/test_async_integration.py @@ -0,0 +1,173 @@ +"""Integration tests for async MemDBConnector with real database connections. + +These tests use the database configuration from the .env file. +The tests will work with Redis databases only as async support is primarily for Redis. +""" + +import asyncio + +import pytest + +from mem_db_utils.asyncio import MemDBConnector as AsyncMemDBConnector +from mem_db_utils.config import DBConfig, DBType + + +class TestAsyncMemDBConnectorIntegration: + """Integration tests with real async database connections.""" + + @pytest.mark.asyncio + async def test_async_database_connection(self): + """Test async connection to the configured database.""" + connector = AsyncMemDBConnector() + + try: + conn = await connector.connect(db=0) + # Test basic operations - works for Redis-compatible databases + if connector.db_type in [DBType.REDIS, DBType.DRAGONFLY, DBType.VALKEY]: + result = await conn.ping() + assert result is True + + # Test set/get operations + await conn.set("test_async_key", "test_async_value") + value = await conn.get("test_async_key") + assert value == "test_async_value" + + # Cleanup + await conn.delete("test_async_key") + + # Close the connection + await conn.aclose() + else: + # For other database types, just verify connection exists + assert conn is not None + + except Exception as e: + pytest.skip(f"Database not available at {DBConfig.db_url}: {e}") + + @pytest.mark.asyncio + async def test_async_database_with_different_db_number(self): + """Test async connection with different database number (Redis-compatible only).""" + connector = AsyncMemDBConnector() + + # Only test for Redis-compatible databases that support db selection + if connector.db_type not in [DBType.REDIS, DBType.DRAGONFLY, DBType.VALKEY]: + pytest.skip(f"Database selection not supported for {connector.db_type}") + + try: + conn = await connector.connect(db=1) + result = await conn.ping() + assert result is True + await conn.aclose() + + except Exception as e: + pytest.skip(f"Database not available at {DBConfig.db_url}: {e}") + + @pytest.mark.asyncio + async def test_async_auto_type_detection(self): + """Test automatic type detection from URL in async context.""" + connector = AsyncMemDBConnector() + assert connector.db_type == DBConfig.db_type + assert connector.uri == DBConfig.db_url + + @pytest.mark.asyncio + async def test_async_connection_with_decode_responses_false(self): + """Test async connection with decode_responses=False (Redis-compatible only).""" + connector = AsyncMemDBConnector() + + # Only test for Redis-compatible databases + if connector.db_type not in [DBType.REDIS, DBType.DRAGONFLY, DBType.VALKEY]: + pytest.skip(f"decode_responses not supported for {connector.db_type}") + + try: + conn = await connector.connect(db=0, decode_response=False) + result = await conn.ping() + assert result is True + await conn.aclose() + + except Exception as e: + pytest.skip(f"Database not available at {DBConfig.db_url}: {e}") + + @pytest.mark.asyncio + async def test_async_connection_timeout_configuration(self): + """Test async connection with configured timeout.""" + connector = AsyncMemDBConnector() + + try: + # This should work with the configured timeout + conn = await connector.connect(db=0) + if connector.db_type in [DBType.REDIS, DBType.DRAGONFLY, DBType.VALKEY]: + result = await conn.ping() + assert result is True + await conn.aclose() + else: + assert conn is not None + + except Exception as e: + pytest.skip(f"Database not available at {DBConfig.db_url}: {e}") + + @pytest.mark.skip(reason="Requires Redis Sentinel setup") + @pytest.mark.asyncio + async def test_async_sentinel_connection_integration(self): + """Test async Redis Sentinel connection (requires Sentinel setup).""" + connector = AsyncMemDBConnector() + + if connector.db_type != DBType.REDIS or connector.connection_type != "sentinel": + pytest.skip("Test requires Redis Sentinel configuration") + + try: + conn = await connector.connect(db=0) + result = await conn.ping() + assert result is True + await conn.aclose() + + except Exception as e: + pytest.skip(f"Redis Sentinel not available: {e}") + + @pytest.mark.asyncio + async def test_async_error_handling_with_invalid_db_number(self): + """Test async error handling with invalid database number.""" + connector = AsyncMemDBConnector() + + # Only test for Redis-compatible databases that support db selection + if connector.db_type not in [DBType.REDIS, DBType.DRAGONFLY, DBType.VALKEY]: + pytest.skip(f"Database selection not supported for {connector.db_type}") + + try: + # Try to connect to a very high database number that likely doesn't exist + with pytest.raises(Exception): # noqa + conn = await connector.connect(db=999) + await conn.ping() + except Exception as e: + pytest.skip(f"Database not available for error testing: {e}") + + @pytest.mark.asyncio + async def test_async_concurrent_connections(self): + """Test multiple concurrent async connections.""" + connector = AsyncMemDBConnector() + + if connector.db_type not in [DBType.REDIS, DBType.DRAGONFLY, DBType.VALKEY]: + pytest.skip(f"Concurrent connections test not supported for {connector.db_type}") + + async def test_connection(conn_id): + try: + conn = await connector.connect(db=0) + await conn.set(f"async_test_key_{conn_id}", f"value_{conn_id}") + value = await conn.get(f"async_test_key_{conn_id}") + assert value == f"value_{conn_id}" + await conn.delete(f"async_test_key_{conn_id}") + await conn.aclose() + return True + except Exception: + return False + + try: + # Test 3 concurrent connections + tasks = [test_connection(i) for i in range(3)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # At least some connections should succeed + success_count = sum(1 for r in results if r is True) + assert success_count > 0 + + except Exception as e: + pytest.skip(f"Database not available for concurrent testing: {e}") diff --git a/uv.lock b/uv.lock index 2bbd458..b3973c6 100644 --- a/uv.lock +++ b/uv.lock @@ -145,6 +145,7 @@ dev = [ { name = "coverage" }, { name = "pre-commit" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-dotenv" }, ] @@ -162,6 +163,7 @@ dev = [ { name = "coverage", specifier = ">=7.10.2" }, { name = "pre-commit", specifier = ">=4.2.0" }, { name = "pytest", specifier = ">=8.4.1" }, + { name = "pytest-asyncio", specifier = ">=1.2.0" }, { name = "pytest-cov", specifier = ">=6.2.1" }, { name = "pytest-dotenv", specifier = ">=0.5.2" }, ] @@ -314,6 +316,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.2.0" +source = { registry = "https://pypi.prismatica.in/simple" } +dependencies = [ + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" }, +] + [[package]] name = "pytest-cov" version = "7.0.0"