Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 212 additions & 0 deletions src/infuse_iot/rpc_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
#!/usr/bin/env python3

import ctypes
import random
from typing import Callable

from infuse_iot import rpc
from infuse_iot.common import InfuseType
from infuse_iot.epacket.packet import Auth, PacketOutput, PacketReceived
from infuse_iot.socket_comms import (
ClientNotification,
ClientNotificationConnectionDropped,
ClientNotificationEpacketReceived,
GatewayRequestEpacketSend,
LocalClient,
)


class RpcClient:
def __init__(
self,
client: LocalClient,
max_payload: int,
infuse_id: int,
rx_cb: Callable[[ClientNotification], None] | None = None,
):
self._request_id = random.randint(0, 2**31 - 1)
self._client = client
self._id = infuse_id
self._max_payload = max_payload
self._rx_cb = rx_cb

def _finalise_command(
self, rpc_rsp: PacketReceived, rsp_decoder: Callable[[bytes], ctypes.LittleEndianStructure]
) -> tuple[rpc.ResponseHeader, ctypes.LittleEndianStructure]:
# Convert response bytes back to struct form
rsp_header = rpc.ResponseHeader.from_buffer_copy(rpc_rsp.payload)
rsp_payload = rpc_rsp.payload[ctypes.sizeof(rpc.ResponseHeader) :]
rsp_data = rsp_decoder(rsp_payload)
return (rsp_header, rsp_data)

def _client_recv(self) -> ClientNotification | None:
rsp = self._client.receive()
if rsp is not None and self._rx_cb is not None:
self._rx_cb(rsp)
return rsp

def _wait_data_ack(self) -> PacketReceived:
while True:
rsp = self._client_recv()
if rsp is None:
continue
if not isinstance(rsp, ClientNotificationEpacketReceived):
continue
if rsp.epacket.ptype == InfuseType.RPC_RSP:
rsp_header = rpc.ResponseHeader.from_buffer_copy(rsp.epacket.payload)
if rsp_header.request_id == self._request_id:
return rsp.epacket
elif rsp.epacket.ptype != InfuseType.RPC_DATA_ACK:
continue
data_ack = rpc.DataAck.from_buffer_copy(rsp.epacket.payload)
# Response to the request we sent
if data_ack.request_id != self._request_id:
continue
return rsp.epacket

def _wait_rpc_rsp(self) -> PacketReceived:
# Wait for responses
while True:
rsp = self._client_recv()
if rsp is None:
continue
if not isinstance(rsp, ClientNotificationEpacketReceived):
continue
# RPC response packet
if rsp.epacket.ptype != InfuseType.RPC_RSP:
continue
rsp_header = rpc.ResponseHeader.from_buffer_copy(rsp.epacket.payload)
# Response to the request we sent
if rsp_header.request_id != self._request_id:
continue
return rsp.epacket

def run_data_send_cmd(
self,
cmd_id: int,
auth: Auth,
params: bytes,
data: bytes,
progress_cb: Callable[[int], None] | None,
rsp_decoder: Callable[[bytes], ctypes.LittleEndianStructure],
) -> tuple[rpc.ResponseHeader, ctypes.LittleEndianStructure]:
self._request_id += 1
ack_period = 1
header = rpc.RequestHeader(self._request_id, cmd_id) # type: ignore
data_hdr = rpc.RequestDataHeader(len(data), ack_period)

request_packet = bytes(header) + bytes(data_hdr) + params
pkt = PacketOutput(
self._id,
auth,
InfuseType.RPC_CMD,
request_packet,
)
req = GatewayRequestEpacketSend(pkt)
self._client.send(req)

# Wait for initial ACK
recv = self._wait_data_ack()
if recv.ptype == InfuseType.RPC_RSP:
return self._finalise_command(recv, rsp_decoder)

# Send data payloads with maximum interface size
ack_cnt = -ack_period
offset = 0
size = self._max_payload - ctypes.sizeof(rpc.DataHeader)
# Round payload down to multiple of 4 bytes
size -= size % 4
while len(data) > 0:
size = min(size, len(data))
payload = data[:size]

hdr = rpc.DataHeader(self._request_id, offset)
pkt_bytes = bytes(hdr) + payload
pkt = PacketOutput(
self._id,
auth,
InfuseType.RPC_DATA,
pkt_bytes,
)
self._client.send(GatewayRequestEpacketSend(pkt))
ack_cnt += 1

# Wait for ACKs at the period
if ack_cnt == ack_period:
self._wait_data_ack()
ack_cnt = 0

offset += size
data = data[size:]
if progress_cb:
progress_cb(offset)

recv = self._wait_rpc_rsp()
return self._finalise_command(recv, rsp_decoder)

def run_data_recv_cmd(
self,
cmd_id: int,
auth: Auth,
params: bytes,
recv_cb: Callable[[int, bytes], None],
rsp_decoder: Callable[[bytes], ctypes.LittleEndianStructure],
) -> tuple[rpc.ResponseHeader, ctypes.LittleEndianStructure]:
self._request_id += 1
header = rpc.RequestHeader(self._request_id, cmd_id)
data_hdr = rpc.RequestDataHeader(0xFFFFFFFF, 0)

request_packet = bytes(header) + bytes(data_hdr) + params
pkt = PacketOutput(
self._id,
auth,
InfuseType.RPC_CMD,
request_packet,
)
req = GatewayRequestEpacketSend(pkt)
self._client.send(req)

while True:
rsp = self._client_recv()
if rsp is None:
continue
if isinstance(rsp, ClientNotificationConnectionDropped):
raise ConnectionAbortedError
if not isinstance(rsp, ClientNotificationEpacketReceived):
continue
if rsp.epacket.ptype == InfuseType.RPC_RSP:
rsp_header = rpc.ResponseHeader.from_buffer_copy(rsp.epacket.payload)
# Response to the request we sent
if rsp_header.request_id != self._request_id:
continue
# Convert response bytes back to struct form
rsp_payload = rsp.epacket.payload[ctypes.sizeof(rpc.ResponseHeader) :]
rsp_data = rsp_decoder(rsp_payload)
return (rsp_header, rsp_data)

if rsp.epacket.ptype != InfuseType.RPC_DATA:
continue
data = rpc.DataHeader.from_buffer_copy(rsp.epacket.payload)
# Response to the request we sent
if data.request_id != self._request_id:
continue

recv_cb(data.offset, rsp.epacket.payload[ctypes.sizeof(rpc.DataHeader) :])

def run_standard_cmd(
self, cmd_id: int, auth: Auth, params: bytes, rsp_decoder: Callable[[bytes], ctypes.LittleEndianStructure]
) -> tuple[rpc.ResponseHeader, ctypes.LittleEndianStructure]:
self._request_id += 1
header = rpc.RequestHeader(self._request_id, cmd_id) # type: ignore

request_packet = bytes(header) + params
pkt = PacketOutput(
self._id,
auth,
InfuseType.RPC_CMD,
request_packet,
)
req = GatewayRequestEpacketSend(pkt)
self._client.send(req)
recv = self._wait_rpc_rsp()
return self._finalise_command(recv, rsp_decoder)
20 changes: 14 additions & 6 deletions src/infuse_iot/socket_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import socket
import struct
from contextlib import contextmanager
from typing import cast

from typing_extensions import Self
Expand Down Expand Up @@ -241,12 +242,12 @@ def connection_create(self, infuse_id: int, data_types: GatewayRequestConnection
req = GatewayRequestConnectionRequest(infuse_id, data_types)
self.send(req)
# Wait for response from the server
while rsp := self.receive():
if isinstance(rsp, ClientNotificationConnectionCreated):
return rsp.max_payload
elif isinstance(rsp, ClientNotificationConnectionFailed):
raise ConnectionRefusedError
raise ConnectionRefusedError
while True:
if rsp := self.receive():
if isinstance(rsp, ClientNotificationConnectionCreated):
return rsp.max_payload
elif isinstance(rsp, ClientNotificationConnectionFailed):
raise ConnectionRefusedError

def connection_release(self):
assert self._connection_id is not None
Expand All @@ -257,6 +258,13 @@ def connection_release(self):
self.send(req)
self._connection_id = None

@contextmanager
def connection(self, infuse_id: int, data_types: GatewayRequestConnectionRequest.DataType):
try:
yield self.connection_create(infuse_id, data_types)
finally:
self.connection_release()

def close(self):
# Cleanup any lingering connection context
if self._connection_id:
Expand Down
20 changes: 10 additions & 10 deletions src/infuse_iot/tools/bt_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ def add_parser(cls, parser):

def run(self):
try:
self._client.connection_create(self._id, GatewayRequestConnectionRequest.DataType.LOGGING)

while rsp := self._client.receive():
if isinstance(rsp, ClientNotificationConnectionDropped):
print(f"Connection to {self._id:016x} lost")
break
if isinstance(rsp, ClientNotificationEpacketReceived) and rsp.epacket.ptype == InfuseType.SERIAL_LOG:
print(rsp.epacket.payload.decode("utf-8"), end="")
with self._client.connection(self._id, GatewayRequestConnectionRequest.DataType.LOGGING) as _:
while rsp := self._client.receive():
if isinstance(rsp, ClientNotificationConnectionDropped):
print(f"Connection to {self._id:016x} lost")
break
if (
isinstance(rsp, ClientNotificationEpacketReceived)
and rsp.epacket.ptype == InfuseType.SERIAL_LOG
):
print(rsp.epacket.payload.decode("utf-8"), end="")

except KeyboardInterrupt:
print(f"Disconnecting from {self._id:016x}")
except ConnectionRefusedError:
print(f"Unable to connect to {self._id:016x}")
finally:
self._client.connection_release()
Loading
Loading