Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed infinitely recursive health checks #3557

Merged
merged 1 commit into from
Mar 26, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
@@ -284,6 +284,9 @@ def set_parser(self, parser_class: Type[BaseParser]) -> None:

async def connect(self):
"""Connects to the Redis server if not already connected"""
await self.connect_check_health(check_health=True)

async def connect_check_health(self, check_health: bool = True):
if self.is_connected:
return
try:
@@ -302,7 +305,7 @@ async def connect(self):
try:
if not self.redis_connect_func:
# Use the default on_connect function
await self.on_connect()
await self.on_connect_check_health(check_health=check_health)
else:
# Use the passed function redis_connect_func
(
@@ -341,6 +344,9 @@ def get_protocol(self):

async def on_connect(self) -> None:
"""Initialize the connection, authenticate and select a database"""
await self.on_connect_check_health(check_health=True)

async def on_connect_check_health(self, check_health: bool = True) -> None:
self._parser.on_connect(self)
parser = self._parser

@@ -398,7 +404,7 @@ async def on_connect(self) -> None:
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
await self.send_command("HELLO", self.protocol)
await self.send_command("HELLO", self.protocol, check_health=check_health)
response = await self.read_response()
# if response.get(b"proto") != self.protocol and response.get(
# "proto"
@@ -407,18 +413,35 @@ async def on_connect(self) -> None:

# if a client_name is given, set it
if self.client_name:
await self.send_command("CLIENT", "SETNAME", self.client_name)
await self.send_command(
"CLIENT",
"SETNAME",
self.client_name,
check_health=check_health,
)
if str_if_bytes(await self.read_response()) != "OK":
raise ConnectionError("Error setting client name")

# set the library name and version, pipeline for lower startup latency
if self.lib_name:
await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
await self.send_command(
"CLIENT",
"SETINFO",
"LIB-NAME",
self.lib_name,
check_health=check_health,
)
if self.lib_version:
await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
await self.send_command(
"CLIENT",
"SETINFO",
"LIB-VER",
self.lib_version,
check_health=check_health,
)
# if a database is specified, switch to it. Also pipeline this
if self.db:
await self.send_command("SELECT", self.db)
await self.send_command("SELECT", self.db, check_health=check_health)

# read responses from pipeline
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
@@ -480,8 +503,8 @@ async def send_packed_command(
self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
) -> None:
if not self.is_connected:
await self.connect()
elif check_health:
await self.connect_check_health(check_health=False)
if check_health:
await self.check_health()

try:
37 changes: 30 additions & 7 deletions redis/connection.py
Original file line number Diff line number Diff line change
@@ -372,6 +372,9 @@ def set_parser(self, parser_class):

def connect(self):
"Connects to the Redis server if not already connected"
self.connect_check_health(check_health=True)

def connect_check_health(self, check_health: bool = True):
if self._sock:
return
try:
@@ -387,7 +390,7 @@ def connect(self):
try:
if self.redis_connect_func is None:
# Use the default on_connect function
self.on_connect()
self.on_connect_check_health(check_health=check_health)
else:
# Use the passed function redis_connect_func
self.redis_connect_func(self)
@@ -417,6 +420,9 @@ def _error_message(self, exception):
return format_error_message(self._host_error(), exception)

def on_connect(self):
self.on_connect_check_health(check_health=True)

def on_connect_check_health(self, check_health: bool = True):
"Initialize the connection, authenticate and select a database"
self._parser.on_connect(self)
parser = self._parser
@@ -475,7 +481,7 @@ def on_connect(self):
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
self.send_command("HELLO", self.protocol)
self.send_command("HELLO", self.protocol, check_health=check_health)
self.handshake_metadata = self.read_response()
if (
self.handshake_metadata.get(b"proto") != self.protocol
@@ -485,28 +491,45 @@ def on_connect(self):

# if a client_name is given, set it
if self.client_name:
self.send_command("CLIENT", "SETNAME", self.client_name)
self.send_command(
"CLIENT",
"SETNAME",
self.client_name,
check_health=check_health,
)
if str_if_bytes(self.read_response()) != "OK":
raise ConnectionError("Error setting client name")

try:
# set the library name and version
if self.lib_name:
self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
self.send_command(
"CLIENT",
"SETINFO",
"LIB-NAME",
self.lib_name,
check_health=check_health,
)
self.read_response()
except ResponseError:
pass

try:
if self.lib_version:
self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
self.send_command(
"CLIENT",
"SETINFO",
"LIB-VER",
self.lib_version,
check_health=check_health,
)
self.read_response()
except ResponseError:
pass

# if a database is specified, switch to it
if self.db:
self.send_command("SELECT", self.db)
self.send_command("SELECT", self.db, check_health=check_health)
if str_if_bytes(self.read_response()) != "OK":
raise ConnectionError("Invalid Database")

@@ -548,7 +571,7 @@ def check_health(self):
def send_packed_command(self, command, check_health=True):
"""Send an already packed command to the Redis server"""
if not self._sock:
self.connect()
self.connect_check_health(check_health=False)
# guard against health check recursion
if check_health:
self.check_health()
Loading