diff --git a/adafruit_esp32spi/adafruit_esp32spi_socket.py b/adafruit_esp32spi/adafruit_esp32spi_socketpool.py similarity index 56% rename from adafruit_esp32spi/adafruit_esp32spi_socket.py rename to adafruit_esp32spi/adafruit_esp32spi_socketpool.py index 347b330..5394c6c 100644 --- a/adafruit_esp32spi/adafruit_esp32spi_socket.py +++ b/adafruit_esp32spi/adafruit_esp32spi_socketpool.py @@ -10,71 +10,96 @@ * Author(s): ladyada """ +from __future__ import annotations + +try: + from typing import TYPE_CHECKING, Optional + + if TYPE_CHECKING: + from esp32spi.adafruit_esp32spi import ESP_SPIcontrol +except ImportError: + pass -# pylint: disable=no-name-in-module import time import gc from micropython import const -from adafruit_esp32spi import adafruit_esp32spi +from adafruit_esp32spi import adafruit_esp32spi as esp32spi -_the_interface = None # pylint: disable=invalid-name +_global_socketpool = {} -def set_interface(iface): - """Helper to set the global internet interface""" - global _the_interface # pylint: disable=global-statement, invalid-name - _the_interface = iface +class SocketPoolContants: # pylint: disable=too-few-public-methods + """Helper class for the constants that are needed everywhere""" + SOCK_STREAM = const(0) + SOCK_DGRAM = const(1) + AF_INET = const(2) + NO_SOCKET_AVAIL = const(255) -SOCK_STREAM = const(0) -SOCK_DGRAM = const(1) -AF_INET = const(2) -NO_SOCKET_AVAIL = const(255) + MAX_PACKET = const(4000) -MAX_PACKET = const(4000) +class SocketPool(SocketPoolContants): + """ESP32SPI SocketPool library""" -# pylint: disable=too-many-arguments, unused-argument -def getaddrinfo(host, port, family=0, socktype=0, proto=0, flags=0): - """Given a hostname and a port name, return a 'socket.getaddrinfo' - compatible list of tuples. Honestly, we ignore anything but host & port""" - if not isinstance(port, int): - raise ValueError("Port must be an integer") - ipaddr = _the_interface.get_host_by_name(host) - return [(AF_INET, socktype, proto, "", (ipaddr, port))] + def __new__(cls, iface: ESP_SPIcontrol): + # We want to make sure to return the same pool for the same interface + if iface not in _global_socketpool: + _global_socketpool[iface] = object.__new__(cls) + return _global_socketpool[iface] + def __init__(self, iface: ESP_SPIcontrol): + self._interface = iface -# pylint: enable=too-many-arguments, unused-argument + def getaddrinfo( # pylint: disable=too-many-arguments,unused-argument + self, host, port, family=0, socktype=0, proto=0, flags=0 + ): + """Given a hostname and a port name, return a 'socket.getaddrinfo' + compatible list of tuples. Honestly, we ignore anything but host & port""" + if not isinstance(port, int): + raise ValueError("Port must be an integer") + ipaddr = self._interface.get_host_by_name(host) + return [(SocketPoolContants.AF_INET, socktype, proto, "", (ipaddr, port))] + + def socket( # pylint: disable=redefined-builtin + self, + family=SocketPoolContants.AF_INET, + type=SocketPoolContants.SOCK_STREAM, + proto=0, + fileno=None, + ): + """Create a new socket and return it""" + return Socket(self, family, type, proto, fileno) -# pylint: disable=unused-argument, redefined-builtin, invalid-name -class socket: +class Socket: """A simplified implementation of the Python 'socket' class, for connecting through an interface to a remote device""" - # pylint: disable=too-many-arguments - def __init__( - self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, socknum=None + def __init__( # pylint: disable=redefined-builtin,too-many-arguments,unused-argument + self, + socket_pool: SocketPool, + family: int = SocketPool.AF_INET, + type: int = SocketPool.SOCK_STREAM, + proto: int = 0, + fileno: Optional[int] = None, ): - if family != AF_INET: + if family != SocketPool.AF_INET: raise ValueError("Only AF_INET family supported") + self._socket_pool = socket_pool + self._interface = self._socket_pool._interface self._type = type self._buffer = b"" - self._socknum = socknum if socknum else _the_interface.get_socket() + self._socknum = self._interface.get_socket(reserve_socket=True) self.settimeout(0) - # pylint: enable=too-many-arguments - def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() - while ( - _the_interface.socket_status(self._socknum) - != adafruit_esp32spi.SOCKET_CLOSED - ): + while self._interface.socket_status(self._socknum) != esp32spi.SOCKET_CLOSED: pass def connect(self, address, conntype=None): @@ -83,20 +108,20 @@ def connect(self, address, conntype=None): depending on the underlying interface""" host, port = address if conntype is None: - conntype = _the_interface.TCP_MODE - if not _the_interface.socket_connect( + conntype = self._interface.TCP_MODE + if not self._interface.socket_connect( self._socknum, host, port, conn_mode=conntype ): raise ConnectionError("Failed to connect to host", host) self._buffer = b"" - def send(self, data): # pylint: disable=no-self-use + def send(self, data): """Send some data to the socket.""" - if self._type is SOCK_DGRAM: - conntype = _the_interface.UDP_MODE + if self._type is SocketPool.SOCK_DGRAM: + conntype = self._interface.UDP_MODE else: - conntype = _the_interface.TCP_MODE - _the_interface.socket_write(self._socknum, data, conn_mode=conntype) + conntype = self._interface.TCP_MODE + self._interface.socket_write(self._socknum, data, conn_mode=conntype) gc.collect() def recv(self, bufsize: int) -> bytes: @@ -140,7 +165,7 @@ def recv_into(self, buffer, nbytes: int = 0): num_avail = self._available() if num_avail > 0: last_read_time = time.monotonic() - bytes_read = _the_interface.socket_read( + bytes_read = self._interface.socket_read( self._socknum, min(num_to_read, num_avail) ) buffer[num_read : num_read + len(bytes_read)] = bytes_read @@ -162,43 +187,42 @@ def settimeout(self, value): def _available(self): """Returns how many bytes of data are available to be read (up to the MAX_PACKET length)""" - if self._socknum != NO_SOCKET_AVAIL: - return min(_the_interface.socket_available(self._socknum), MAX_PACKET) + if self._socknum != SocketPool.NO_SOCKET_AVAIL: + return min( + self._interface.socket_available(self._socknum), SocketPool.MAX_PACKET + ) return 0 def _connected(self): """Whether or not we are connected to the socket""" - if self._socknum == NO_SOCKET_AVAIL: + if self._socknum == SocketPool.NO_SOCKET_AVAIL: return False if self._available(): return True - status = _the_interface.socket_status(self._socknum) + status = self._interface.socket_status(self._socknum) result = status not in ( - adafruit_esp32spi.SOCKET_LISTEN, - adafruit_esp32spi.SOCKET_CLOSED, - adafruit_esp32spi.SOCKET_FIN_WAIT_1, - adafruit_esp32spi.SOCKET_FIN_WAIT_2, - adafruit_esp32spi.SOCKET_TIME_WAIT, - adafruit_esp32spi.SOCKET_SYN_SENT, - adafruit_esp32spi.SOCKET_SYN_RCVD, - adafruit_esp32spi.SOCKET_CLOSE_WAIT, + esp32spi.SOCKET_LISTEN, + esp32spi.SOCKET_CLOSED, + esp32spi.SOCKET_FIN_WAIT_1, + esp32spi.SOCKET_FIN_WAIT_2, + esp32spi.SOCKET_TIME_WAIT, + esp32spi.SOCKET_SYN_SENT, + esp32spi.SOCKET_SYN_RCVD, + esp32spi.SOCKET_CLOSE_WAIT, ) if not result: self.close() - self._socknum = NO_SOCKET_AVAIL + self._socknum = SocketPool.NO_SOCKET_AVAIL return result def close(self): """Close the socket, after reading whatever remains""" - _the_interface.socket_close(self._socknum) + self._interface.socket_close(self._socknum) -class timeout(TimeoutError): +class timeout(TimeoutError): # pylint: disable=invalid-name """TimeoutError class. An instance of this error will be raised by recv_into() if the timeout has elapsed and we haven't received any data yet.""" def __init__(self, msg): super().__init__(msg) - - -# pylint: enable=unused-argument, redefined-builtin, invalid-name