Skip to content

Commit

Permalink
Convert to SocketPool
Browse files Browse the repository at this point in the history
  • Loading branch information
justmobilize committed Apr 24, 2024
1 parent 956d6a0 commit 646d60c
Showing 1 changed file with 84 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,71 +10,96 @@
* Author(s): ladyada
"""
from __future__ import annotations

try:
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from esp32spi.adafruit_esp32spi import ESP_SPIcontrol
except ImportError:
pass

# pylint: disable=no-name-in-module

import time
import gc
from micropython import const
from adafruit_esp32spi import adafruit_esp32spi
from adafruit_esp32spi import adafruit_esp32spi as esp32spi

_the_interface = None # pylint: disable=invalid-name
_global_socketpool = {}


def set_interface(iface):
"""Helper to set the global internet interface"""
global _the_interface # pylint: disable=global-statement, invalid-name
_the_interface = iface
class SocketPoolContants: # pylint: disable=too-few-public-methods
"""Helper class for the constants that are needed everywhere"""

SOCK_STREAM = const(0)
SOCK_DGRAM = const(1)
AF_INET = const(2)
NO_SOCKET_AVAIL = const(255)

SOCK_STREAM = const(0)
SOCK_DGRAM = const(1)
AF_INET = const(2)
NO_SOCKET_AVAIL = const(255)
MAX_PACKET = const(4000)

MAX_PACKET = const(4000)

class SocketPool(SocketPoolContants):
"""ESP32SPI SocketPool library"""

# pylint: disable=too-many-arguments, unused-argument
def getaddrinfo(host, port, family=0, socktype=0, proto=0, flags=0):
"""Given a hostname and a port name, return a 'socket.getaddrinfo'
compatible list of tuples. Honestly, we ignore anything but host & port"""
if not isinstance(port, int):
raise ValueError("Port must be an integer")
ipaddr = _the_interface.get_host_by_name(host)
return [(AF_INET, socktype, proto, "", (ipaddr, port))]
def __new__(cls, iface: ESP_SPIcontrol):
# We want to make sure to return the same pool for the same interface
if iface not in _global_socketpool:
_global_socketpool[iface] = object.__new__(cls)
return _global_socketpool[iface]

def __init__(self, iface: ESP_SPIcontrol):
self._interface = iface

# pylint: enable=too-many-arguments, unused-argument
def getaddrinfo( # pylint: disable=too-many-arguments,unused-argument
self, host, port, family=0, socktype=0, proto=0, flags=0
):
"""Given a hostname and a port name, return a 'socket.getaddrinfo'
compatible list of tuples. Honestly, we ignore anything but host & port"""
if not isinstance(port, int):
raise ValueError("Port must be an integer")
ipaddr = self._interface.get_host_by_name(host)
return [(SocketPoolContants.AF_INET, socktype, proto, "", (ipaddr, port))]

def socket( # pylint: disable=redefined-builtin
self,
family=SocketPoolContants.AF_INET,
type=SocketPoolContants.SOCK_STREAM,
proto=0,
fileno=None,
):
"""Create a new socket and return it"""
return Socket(self, family, type, proto, fileno)


# pylint: disable=unused-argument, redefined-builtin, invalid-name
class socket:
class Socket:
"""A simplified implementation of the Python 'socket' class, for connecting
through an interface to a remote device"""

# pylint: disable=too-many-arguments
def __init__(
self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, socknum=None
def __init__( # pylint: disable=redefined-builtin,too-many-arguments,unused-argument
self,
socket_pool: SocketPool,
family: int = SocketPool.AF_INET,
type: int = SocketPool.SOCK_STREAM,
proto: int = 0,
fileno: Optional[int] = None,
):
if family != AF_INET:
if family != SocketPool.AF_INET:
raise ValueError("Only AF_INET family supported")
self._socket_pool = socket_pool
self._interface = self._socket_pool._interface
self._type = type
self._buffer = b""
self._socknum = socknum if socknum else _the_interface.get_socket()
self._socknum = self._interface.get_socket(reserve_socket=True)
self.settimeout(0)

# pylint: enable=too-many-arguments

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()
while (
_the_interface.socket_status(self._socknum)
!= adafruit_esp32spi.SOCKET_CLOSED
):
while self._interface.socket_status(self._socknum) != esp32spi.SOCKET_CLOSED:
pass

def connect(self, address, conntype=None):
Expand All @@ -83,20 +108,20 @@ def connect(self, address, conntype=None):
depending on the underlying interface"""
host, port = address
if conntype is None:
conntype = _the_interface.TCP_MODE
if not _the_interface.socket_connect(
conntype = self._interface.TCP_MODE
if not self._interface.socket_connect(
self._socknum, host, port, conn_mode=conntype
):
raise ConnectionError("Failed to connect to host", host)
self._buffer = b""

def send(self, data): # pylint: disable=no-self-use
def send(self, data):
"""Send some data to the socket."""
if self._type is SOCK_DGRAM:
conntype = _the_interface.UDP_MODE
if self._type is SocketPool.SOCK_DGRAM:
conntype = self._interface.UDP_MODE
else:
conntype = _the_interface.TCP_MODE
_the_interface.socket_write(self._socknum, data, conn_mode=conntype)
conntype = self._interface.TCP_MODE
self._interface.socket_write(self._socknum, data, conn_mode=conntype)
gc.collect()

def recv(self, bufsize: int) -> bytes:
Expand Down Expand Up @@ -140,7 +165,7 @@ def recv_into(self, buffer, nbytes: int = 0):
num_avail = self._available()
if num_avail > 0:
last_read_time = time.monotonic()
bytes_read = _the_interface.socket_read(
bytes_read = self._interface.socket_read(
self._socknum, min(num_to_read, num_avail)
)
buffer[num_read : num_read + len(bytes_read)] = bytes_read
Expand All @@ -162,43 +187,42 @@ def settimeout(self, value):

def _available(self):
"""Returns how many bytes of data are available to be read (up to the MAX_PACKET length)"""
if self._socknum != NO_SOCKET_AVAIL:
return min(_the_interface.socket_available(self._socknum), MAX_PACKET)
if self._socknum != SocketPool.NO_SOCKET_AVAIL:
return min(
self._interface.socket_available(self._socknum), SocketPool.MAX_PACKET
)
return 0

def _connected(self):
"""Whether or not we are connected to the socket"""
if self._socknum == NO_SOCKET_AVAIL:
if self._socknum == SocketPool.NO_SOCKET_AVAIL:
return False
if self._available():
return True
status = _the_interface.socket_status(self._socknum)
status = self._interface.socket_status(self._socknum)
result = status not in (
adafruit_esp32spi.SOCKET_LISTEN,
adafruit_esp32spi.SOCKET_CLOSED,
adafruit_esp32spi.SOCKET_FIN_WAIT_1,
adafruit_esp32spi.SOCKET_FIN_WAIT_2,
adafruit_esp32spi.SOCKET_TIME_WAIT,
adafruit_esp32spi.SOCKET_SYN_SENT,
adafruit_esp32spi.SOCKET_SYN_RCVD,
adafruit_esp32spi.SOCKET_CLOSE_WAIT,
esp32spi.SOCKET_LISTEN,
esp32spi.SOCKET_CLOSED,
esp32spi.SOCKET_FIN_WAIT_1,
esp32spi.SOCKET_FIN_WAIT_2,
esp32spi.SOCKET_TIME_WAIT,
esp32spi.SOCKET_SYN_SENT,
esp32spi.SOCKET_SYN_RCVD,
esp32spi.SOCKET_CLOSE_WAIT,
)
if not result:
self.close()
self._socknum = NO_SOCKET_AVAIL
self._socknum = SocketPool.NO_SOCKET_AVAIL
return result

def close(self):
"""Close the socket, after reading whatever remains"""
_the_interface.socket_close(self._socknum)
self._interface.socket_close(self._socknum)


class timeout(TimeoutError):
class timeout(TimeoutError): # pylint: disable=invalid-name
"""TimeoutError class. An instance of this error will be raised by recv_into() if
the timeout has elapsed and we haven't received any data yet."""

def __init__(self, msg):
super().__init__(msg)


# pylint: enable=unused-argument, redefined-builtin, invalid-name

0 comments on commit 646d60c

Please sign in to comment.