Skip to content

Allows asyncio cluster mode connections to wait for free connection when at max. #3359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Adds capability for cluster mode to await free connection instead of raising.
* Move doctests (doc code examples) to main branch
* Update `ResponseT` type hint
* Allow to control the minimum SSL version
45 changes: 30 additions & 15 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import asyncio
import collections
import random
import socket
import ssl
import warnings
from typing import (
Any,
Callable,
Deque,
Dict,
Generator,
List,
@@ -65,9 +63,9 @@
RedisClusterException,
ResponseError,
SlotNotCoveredError,
TimeoutError,
TryAgainError,
)
from redis.exceptions import TimeoutError as RedisTimeoutError
from redis.exceptions import TryAgainError
from redis.typing import AnyKeyT, EncodableT, KeyT
from redis.utils import (
deprecated_function,
@@ -264,6 +262,7 @@ def __init__(
socket_timeout: Optional[float] = None,
retry: Optional["Retry"] = None,
retry_on_error: Optional[List[Type[Exception]]] = None,
wait_for_connections: bool = False,
# SSL related kwargs
ssl: bool = False,
ssl_ca_certs: Optional[str] = None,
@@ -326,6 +325,7 @@ def __init__(
"socket_timeout": socket_timeout,
"retry": retry,
"protocol": protocol,
"wait_for_connections": wait_for_connections,
# Client cache related kwargs
"cache_enabled": cache_enabled,
"client_cache": client_cache,
@@ -364,7 +364,7 @@ def __init__(
)
if not retry_on_error:
# Default errors for retrying
retry_on_error = [ConnectionError, TimeoutError]
retry_on_error = [ConnectionError, RedisTimeoutError]
self.retry.update_supported_errors(retry_on_error)
kwargs.update({"retry": self.retry})

@@ -800,7 +800,7 @@ async def _execute_command(
return await target_node.execute_command(*args, **kwargs)
except (BusyLoadingError, MaxConnectionsError):
raise
except (ConnectionError, TimeoutError):
except (ConnectionError, RedisTimeoutError):
# Connection retries are being handled in the node's
# Retry object.
# Remove the failed node from the startup nodes before we try
@@ -962,6 +962,7 @@ class ClusterNode:
__slots__ = (
"_connections",
"_free",
"acquire_connection_timeout",
"connection_class",
"connection_kwargs",
"host",
@@ -970,6 +971,7 @@ class ClusterNode:
"port",
"response_callbacks",
"server_type",
"wait_for_connections",
)

def __init__(
@@ -980,6 +982,7 @@ def __init__(
*,
max_connections: int = 2**31,
connection_class: Type[Connection] = Connection,
wait_for_connections: bool = False,
**connection_kwargs: Any,
) -> None:
if host == "localhost":
@@ -996,9 +999,11 @@ def __init__(
self.connection_class = connection_class
self.connection_kwargs = connection_kwargs
self.response_callbacks = connection_kwargs.pop("response_callbacks", {})
self.acquire_connection_timeout = connection_kwargs.get("socket_timeout", 30)

self._connections: List[Connection] = []
self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
self._free: asyncio.Queue[Connection] = asyncio.Queue()
self.wait_for_connections = wait_for_connections

def __repr__(self) -> str:
return (
@@ -1039,14 +1044,24 @@ async def disconnect(self) -> None:
if exc:
raise exc

def acquire_connection(self) -> Connection:
async def acquire_connection(self) -> Connection:
try:
return self._free.popleft()
except IndexError:
return self._free.get_nowait()
except asyncio.QueueEmpty:
if len(self._connections) < self.max_connections:
connection = self.connection_class(**self.connection_kwargs)
self._connections.append(connection)
return connection
elif self.wait_for_connections:
try:
connection = await asyncio.wait_for(
self._free.get(), self.acquire_connection_timeout
)
return connection
except TimeoutError:
raise RedisTimeoutError(
"Timeout reached waiting for a free connection"
)

raise MaxConnectionsError()

@@ -1075,12 +1090,12 @@ async def parse_response(

async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
# Acquire connection
connection = self.acquire_connection()
connection = await self.acquire_connection()
keys = kwargs.pop("keys", None)

response_from_cache = await connection._get_from_local_cache(args)
if response_from_cache is not None:
self._free.append(connection)
await self._free.put(connection)
return response_from_cache
else:
# Execute command
@@ -1094,11 +1109,11 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
return response
finally:
# Release connection
self._free.append(connection)
await self._free.put(connection)

async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
connection = self.acquire_connection()
connection = await self.acquire_connection()

# Execute command
await connection.send_packed_command(
@@ -1117,7 +1132,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
ret = True

# Release connection
self._free.append(connection)
await self._free.put(connection)

return ret

22 changes: 22 additions & 0 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
@@ -464,6 +464,28 @@ async def read_response_mocked(*args: Any, **kwargs: Any) -> None:

await rc.aclose()

async def test_max_connections_waited(
self, create_redis: Callable[..., RedisCluster]
) -> None:
rc = await create_redis(
cls=RedisCluster, max_connections=10, wait_for_connections=True
)
for node in rc.get_nodes():
assert node.max_connections == 10

with mock.patch.object(Connection, "read_response") as read_response:

async def read_response_mocked(*args: Any, **kwargs: Any) -> None:
await asyncio.sleep(1)

read_response.side_effect = read_response_mocked

await asyncio.gather(
*(rc.ping(target_nodes=RedisCluster.DEFAULT_NODE) for _ in range(20))
)
assert len(rc.get_default_node()._connections) == 10
await rc.aclose()

async def test_execute_command_errors(self, r: RedisCluster) -> None:
"""
Test that if no key is provided then exception should be raised.