diff --git a/aioredis/connection.py b/aioredis/connection.py index 4a2f182a3..74877bd8b 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -23,6 +23,7 @@ Type, TypeVar, Union, + cast, ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse @@ -214,7 +215,7 @@ def on_connect(self, connection: "Connection"): async def can_read(self, timeout: float) -> bool: raise NotImplementedError() - async def read_response(self) -> Union[EncodableT, ResponseError, None]: + async def read_response(self) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: raise NotImplementedError() @@ -516,11 +517,12 @@ async def read_from_socket( return False raise ConnectionError(f"Error while reading from socket: {ex.args}") - async def read_response(self) -> EncodableT: + async def read_response(self) -> Union[EncodableT, List[EncodableT]]: if not self._stream or not self._reader: self.on_disconnect() raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + response: Union[EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]]] # _next_response might be cached from a can_read() call if self._next_response is not False: response = self._next_response @@ -538,12 +540,13 @@ async def read_response(self) -> EncodableT: if isinstance(response, ConnectionError): raise response elif ( - isinstance(response, list) # type: ignore[unreachable] + isinstance(response, list) and response and isinstance(response[0], ConnectionError) ): raise response[0] - return response + # cast as there won't be a ConnectionError here. + return cast(Union[EncodableT, List[EncodableT]], response) DefaultParser: Type[Union[PythonParser, HiredisParser]]