From 8bb706d09fa2cdaa3e2a3caf830dc92b26add4cc Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 15 Oct 2022 12:25:37 -1000 Subject: [PATCH] feat: add a retry_bluetooth_connection_error decorator (#53) --- src/bleak_retry_connector/__init__.py | 95 +++++++++++++++++++-------- tests/test_init.py | 50 ++++++++++++++ 2 files changed, 117 insertions(+), 28 deletions(-) diff --git a/src/bleak_retry_connector/__init__.py b/src/bleak_retry_connector/__init__.py index 96c3354..cf8358c 100644 --- a/src/bleak_retry_connector/__init__.py +++ b/src/bleak_retry_connector/__init__.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import cast + __version__ = "2.2.0" @@ -9,7 +11,7 @@ import platform import time from collections.abc import Callable, Generator -from typing import Any +from typing import Any, TypeVar import async_timeout from bleak import BleakClient, BleakError @@ -20,6 +22,7 @@ DISCONNECT_TIMEOUT = 5 IS_LINUX = platform.system() == "Linux" +DEFAULT_ATTEMPTS = 2 if IS_LINUX: from .dbus import disconnect_devices @@ -47,6 +50,7 @@ "close_stale_connections", "get_device", "get_device_by_adapter", + "retry_bluetooth_connection_error", "BleakClientWithServiceCache", "BleakAbortedError", "BleakNotFoundError", @@ -176,6 +180,15 @@ def address_to_bluez_path(address: str, adapter: str | None = None) -> str: return f"/org/bluez/{adapter or 'hciX'}/dev_{address.upper().replace(':', '_')}" +def calculate_backoff_time(exc: Exception) -> float: + """Calculate the backoff time based on the exception.""" + if isinstance( + exc, (BleakDBusError, EOFError, asyncio.TimeoutError, BrokenPipeError) + ): + return BLEAK_DBUS_BACKOFF_TIME + return BLEAK_BACKOFF_TIME + + async def get_device(address: str) -> BLEDevice | None: """Get the device.""" if not IS_LINUX: @@ -489,7 +502,8 @@ def _raise_if_needed(name: str, description: str, exc: Exception) -> None: attempt, rssi, ) - await wait_for_disconnect(device, BLEAK_DBUS_BACKOFF_TIME) + backoff_time = calculate_backoff_time(exc) + await wait_for_disconnect(device, backoff_time) _raise_if_needed(name, description, exc) except BrokenPipeError as exc: # BrokenPipeError is raised by dbus-next when the device disconnects @@ -515,17 +529,18 @@ def _raise_if_needed(name: str, description: str, exc: Exception) -> None: _raise_if_needed(name, description, exc) except EOFError as exc: transient_errors += 1 + backoff_time = calculate_backoff_time(exc) if debug_enabled: _LOGGER.debug( "%s - %s: Failed to connect: %s, backing off: %s (attempt: %s, last rssi: %s)", name, description, str(exc), - BLEAK_DBUS_BACKOFF_TIME, + backoff_time, attempt, rssi, ) - await wait_for_disconnect(device, BLEAK_DBUS_BACKOFF_TIME) + await wait_for_disconnect(device, backoff_time) _raise_if_needed(name, description, exc) except BLEAK_EXCEPTIONS as exc: bleak_error = str(exc) @@ -533,30 +548,18 @@ def _raise_if_needed(name: str, description: str, exc: Exception) -> None: transient_errors += 1 else: connect_errors += 1 - if isinstance(exc, BleakDBusError): - if debug_enabled: - _LOGGER.debug( - "%s - %s: Failed to connect: %s, backing off: %s (attempt: %s, last rssi: %s)", - name, - description, - bleak_error, - BLEAK_DBUS_BACKOFF_TIME, - attempt, - rssi, - ) - await wait_for_disconnect(device, BLEAK_DBUS_BACKOFF_TIME) - else: - if debug_enabled: - _LOGGER.debug( - "%s - %s: Failed to connect: %s, backing off: %s (attempt: %s, last rssi: %s)", - name, - description, - bleak_error, - BLEAK_BACKOFF_TIME, - attempt, - rssi, - ) - await wait_for_disconnect(device, BLEAK_BACKOFF_TIME) + backoff_time = calculate_backoff_time(exc) + if debug_enabled: + _LOGGER.debug( + "%s - %s: Failed to connect: %s, backing off: %s (attempt: %s, last rssi: %s)", + name, + description, + bleak_error, + backoff_time, + attempt, + rssi, + ) + await wait_for_disconnect(device, backoff_time) _raise_if_needed(name, description, exc) else: if debug_enabled: @@ -573,3 +576,39 @@ def _raise_if_needed(name: str, description: str, exc: Exception) -> None: await asyncio.sleep(0) raise RuntimeError("This should never happen") + + +WrapFuncType = TypeVar("WrapFuncType", bound=Callable[..., Any]) + + +def retry_bluetooth_connection_error(attempts: int = DEFAULT_ATTEMPTS) -> WrapFuncType: + """Define a wrapper to retry on bluetooth connection error.""" + + def _decorator_retry_bluetooth_connection_error(func: WrapFuncType) -> WrapFuncType: + """Define a wrapper to retry on bleak error. + + The accessory is allowed to disconnect us any time so + we need to retry the operation. + """ + + async def _async_wrap_bluetooth_connection_error_retry( + *args: Any, **kwargs: Any + ) -> Any: + for attempt in range(attempts): + try: + return await func(*args, **kwargs) + except BLEAK_EXCEPTIONS as ex: + backoff_time = calculate_backoff_time(ex) + if attempt == attempts - 1: + raise + _LOGGER.debug( + "Bleak error calling %s, backing off: %s, retrying...", + func, + backoff_time, + exc_info=True, + ) + await asyncio.sleep(backoff_time) + + return cast(WrapFuncType, _async_wrap_bluetooth_connection_error_retry) + + return cast(WrapFuncType, _decorator_retry_bluetooth_connection_error) diff --git a/tests/test_init.py b/tests/test_init.py index c51db03..cab18ed 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -7,19 +7,24 @@ from bleak.backends.bluezdbus import defs from bleak.backends.device import BLEDevice from bleak.backends.service import BleakGATTServiceCollection +from bleak.exc import BleakDBusError import bleak_retry_connector from bleak_retry_connector import ( + BLEAK_BACKOFF_TIME, + BLEAK_DBUS_BACKOFF_TIME, MAX_TRANSIENT_ERRORS, BleakAbortedError, BleakClientWithServiceCache, BleakConnectionError, BleakNotFoundError, ble_device_has_changed, + calculate_backoff_time, establish_connection, get_connected_devices, get_device, get_device_by_adapter, + retry_bluetooth_connection_error, ) @@ -1427,3 +1432,48 @@ def __init__(self): assert device_hci0.details["path"] == "/org/bluez/hci0/dev_FA_23_9D_AA_45_46" assert device_hci1 is not None assert device_hci1.details["path"] == "/org/bluez/hci1/dev_FA_23_9D_AA_45_46" + + +def test_calculate_backoff_time(): + """Test that the backoff time is calculated correctly.""" + assert calculate_backoff_time(Exception()) == BLEAK_BACKOFF_TIME + assert ( + calculate_backoff_time(BleakDBusError(MagicMock(), MagicMock())) + == BLEAK_DBUS_BACKOFF_TIME + ) + + +@pytest.mark.asyncio +async def test_retry_bluetooth_connection_error(): + """Test that the retry_bluetooth_connection_error decorator works correctly.""" + + @retry_bluetooth_connection_error() # type: ignore[misc] + async def test_function(): + raise BleakDBusError(MagicMock(), MagicMock()) + + with patch( + "bleak_retry_connector.calculate_backoff_time" + ) as mock_calculate_backoff_time: + mock_calculate_backoff_time.return_value = 0 + with pytest.raises(BleakDBusError): + await test_function() + + assert mock_calculate_backoff_time.call_count == 2 + + +@pytest.mark.asyncio +async def test_retry_bluetooth_connection_error_non_default_max_attempts(): + """Test that the retry_bluetooth_connection_error decorator works correctly with a different number of retries.""" + + @retry_bluetooth_connection_error(4) # type: ignore[misc] + async def test_function(): + raise BleakDBusError(MagicMock(), MagicMock()) + + with patch( + "bleak_retry_connector.calculate_backoff_time" + ) as mock_calculate_backoff_time: + mock_calculate_backoff_time.return_value = 0 + with pytest.raises(BleakDBusError): + await test_function() + + assert mock_calculate_backoff_time.call_count == 4