Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to using ConnectionManager #197

Merged
merged 14 commits into from Feb 29, 2024
Merged
6 changes: 5 additions & 1 deletion .gitignore
Expand Up @@ -47,5 +47,9 @@ _build
.vscode
*~

# tox local cache
# tox-specific files
.tox
build

# coverage-specific files
.coverage
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Expand Up @@ -39,4 +39,4 @@ repos:
types: [python]
files: "^tests/"
args:
- --disable=missing-docstring,consider-using-f-string,duplicate-code
- --disable=missing-docstring,invalid-name,consider-using-f-string,duplicate-code
147 changes: 22 additions & 125 deletions adafruit_minimqtt/adafruit_minimqtt.py
Expand Up @@ -26,12 +26,21 @@
* Adafruit CircuitPython firmware for the supported boards:
https://github.com/adafruit/circuitpython/releases

* Adafruit's Connection Manager library:
https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager

"""
import errno
import struct
import time
from random import randint

from adafruit_connection_manager import (
get_connection_manager,
SocketGetOSError,
SocketConnectMemoryError,
)

try:
from typing import List, Optional, Tuple, Type, Union
except ImportError:
Expand Down Expand Up @@ -78,68 +87,19 @@
_default_sock = None # pylint: disable=invalid-name
_fake_context = None # pylint: disable=invalid-name

TemporaryError = (SocketGetOSError, SocketConnectMemoryError)


class MMQTTException(Exception):
"""MiniMQTT Exception class."""

# pylint: disable=unnecessary-pass
# pass


class TemporaryError(Exception):
"""Temporary error class used for handling reconnects."""


# Legacy ESP32SPI Socket API
def set_socket(sock, iface=None) -> None:
"""Legacy API for setting the socket and network interface.

:param sock: socket object.
:param iface: internet interface object

"""
global _default_sock # pylint: disable=invalid-name, global-statement
global _fake_context # pylint: disable=invalid-name, global-statement
_default_sock = sock
if iface:
_default_sock.set_interface(iface)
_fake_context = _FakeSSLContext(iface)


class _FakeSSLSocket:
def __init__(self, socket, tls_mode) -> None:
self._socket = socket
self._mode = tls_mode
self.settimeout = socket.settimeout
self.send = socket.send
self.recv = socket.recv
self.close = socket.close

def connect(self, address):
"""connect wrapper to add non-standard mode parameter"""
try:
return self._socket.connect(address, self._mode)
except RuntimeError as error:
raise OSError(errno.ENOMEM) from error


class _FakeSSLContext:
def __init__(self, iface) -> None:
self._iface = iface

def wrap_socket(self, socket, server_hostname=None) -> _FakeSSLSocket:
"""Return the same socket"""
# pylint: disable=unused-argument
return _FakeSSLSocket(socket, self._iface.TLS_MODE)


class NullLogger:
"""Fake logger class that does not do anything"""

# pylint: disable=unused-argument
def nothing(self, msg: str, *args) -> None:
"""no action"""
pass

def __init__(self) -> None:
for log_level in ["debug", "info", "warning", "error", "critical"]:
Expand Down Expand Up @@ -194,6 +154,7 @@ def __init__(
user_data=None,
use_imprecise_time: Optional[bool] = None,
) -> None:
self._connection_manager = get_connection_manager(socket_pool)
self._socket_pool = socket_pool
self._ssl_context = ssl_context
self._sock = None
Expand Down Expand Up @@ -300,77 +261,6 @@ def get_monotonic_time(self) -> float:

return time.monotonic()

# pylint: disable=too-many-branches
def _get_connect_socket(self, host: str, port: int, *, timeout: int = 1):
"""Obtains a new socket and connects to a broker.

:param str host: Desired broker hostname
:param int port: Desired broker port
:param int timeout: Desired socket timeout, in seconds
"""
# For reconnections - check if we're using a socket already and close it
if self._sock:
self._sock.close()
self._sock = None

# Legacy API - use the interface's socket instead of a passed socket pool
if self._socket_pool is None:
self._socket_pool = _default_sock

# Legacy API - fake the ssl context
if self._ssl_context is None:
self._ssl_context = _fake_context

if not isinstance(port, int):
raise RuntimeError("Port must be an integer")

if self._is_ssl and not self._ssl_context:
raise RuntimeError(
"ssl_context must be set before using adafruit_mqtt for secure MQTT."
)

if self._is_ssl:
self.logger.info(f"Establishing a SECURE SSL connection to {host}:{port}")
else:
self.logger.info(f"Establishing an INSECURE connection to {host}:{port}")

addr_info = self._socket_pool.getaddrinfo(
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]

try:
sock = self._socket_pool.socket(addr_info[0], addr_info[1])
except OSError as exc:
# Do not consider this for back-off.
self.logger.warning(
f"Failed to create socket for host {addr_info[0]} and port {addr_info[1]}"
)
raise TemporaryError from exc

connect_host = addr_info[-1][0]
if self._is_ssl:
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
connect_host = host
sock.settimeout(timeout)

last_exception = None
try:
sock.connect((connect_host, port))
except MemoryError as exc:
sock.close()
self.logger.warning(f"Failed to allocate memory for connect: {exc}")
# Do not consider this for back-off.
raise TemporaryError from exc
except OSError as exc:
sock.close()
last_exception = exc

if last_exception:
raise last_exception

self._backwards_compatible_sock = not hasattr(sock, "recv_into")
return sock

def __enter__(self):
return self

Expand Down Expand Up @@ -593,8 +483,15 @@ def _connect(
time.sleep(self._reconnect_timeout)

# Get a new socket
self._sock = self._get_connect_socket(
self.broker, self.port, timeout=self._socket_timeout
self._sock = self._connection_manager.get_socket(
self.broker,
self.port,
"mqtt:",
timeout=self._socket_timeout,
is_ssl=self._is_ssl,
ssl_context=self._ssl_context,
max_retries=1, # setting to 1 since we want to handle backoff internally
exception_passthrough=True,
)

# Fixed Header
Expand Down Expand Up @@ -689,7 +586,7 @@ def disconnect(self) -> None:
except RuntimeError as e:
self.logger.warning(f"Unable to send DISCONNECT packet: {e}")
self.logger.debug("Closing socket")
self._sock.close()
self._connection_manager.free_socket(self._sock)
self._is_connected = False
self._subscribed_topics = []
if self.on_disconnect is not None:
Expand Down
17 changes: 17 additions & 0 deletions conftest.py
@@ -0,0 +1,17 @@
# SPDX-FileCopyrightText: 2023 Justin Myers for Adafruit Industries
brentru marked this conversation as resolved.
Show resolved Hide resolved
#
# SPDX-License-Identifier: Unlicense

""" PyTest Setup """

import pytest
import adafruit_connection_manager


@pytest.fixture(autouse=True)
def reset_connection_manager(monkeypatch):
"""Reset the ConnectionManager, since it's a singlton and will hold data"""
monkeypatch.setattr(
"adafruit_minimqtt.adafruit_minimqtt.get_connection_manager",
adafruit_connection_manager.ConnectionManager,
)
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -3,3 +3,4 @@
# SPDX-License-Identifier: Unlicense

Adafruit-Blinka
Adafruit-Circuitpython-ConnectionManager@git+https://github.com/justmobilize/Adafruit_CircuitPython_ConnectionManager@connection-manager
justmobilize marked this conversation as resolved.
Show resolved Hide resolved
37 changes: 33 additions & 4 deletions tox.ini
Expand Up @@ -3,9 +3,38 @@
# SPDX-License-Identifier: MIT

[tox]
envlist = py39
envlist = py311

[testenv]
changedir = {toxinidir}/tests
deps = pytest==6.2.5
commands = pytest -v
description = run tests
deps =
pytest==7.4.3
pytest-subtests==0.11.0
commands = pytest

[testenv:coverage]
description = run coverage
deps =
pytest==7.4.3
pytest-cov==4.1.0
pytest-subtests==0.11.0
package = editable
commands =
coverage run --source=. --omit=tests/* --branch {posargs} -m pytest
coverage report
coverage html

[testenv:lint]
description = run linters
deps =
pre-commit==3.6.0
skip_install = true
commands = pre-commit run {posargs}

[testenv:docs]
description = build docs
deps =
-r requirements.txt
-r docs/requirements.txt
skip_install = true
commands = sphinx-build -E -W -b html docs/. _build/html