diff --git a/src/infuse_iot/common.py b/src/infuse_iot/common.py index 13fc517..5986620 100644 --- a/src/infuse_iot/common.py +++ b/src/infuse_iot/common.py @@ -3,7 +3,7 @@ import enum -class InfuseType(enum.Enum): +class InfuseType(enum.IntEnum): """Infuse Data Types""" ECHO_REQ = 0 @@ -15,7 +15,14 @@ class InfuseType(enum.Enum): RPC_RSP = 6 RECEIVED_EPACKET = 7 ACK = 8 + EPACKET_FORWARD = 9 SERIAL_LOG = 10 MEMFAULT_CHUNK = 30 KEY_IDS = 127 + + +class InfuseID(enum.IntEnum): + """Hardcoded Infuse IDs""" + + GATEWAY = -1 diff --git a/src/infuse_iot/database.py b/src/infuse_iot/database.py index bc5fe7f..18bc4d9 100644 --- a/src/infuse_iot/database.py +++ b/src/infuse_iot/database.py @@ -47,8 +47,15 @@ def __init__(self, address, network_id=None, device_id=None): self.address = address self.network_id = network_id self.device_id = device_id + self.bt_addr: InterfaceAddress.BluetoothLeAddr | None = None self.public_key = None self.shared_key = None + self._tx_gatt_seq = 0 + + def gatt_sequence_num(self): + """Persistent auto-incrementing sequence number for GATT""" + self._tx_gatt_seq += 1 + return self._tx_gatt_seq def __init__(self): self.gateway = None @@ -65,8 +72,6 @@ def observe_device( """Update device state based on observed packet""" if self.gateway is None: self.gateway = address - if bt_addr is not None: - self.bt_addr[bt_addr] = address if address not in self.devices: self.devices[address] = self.DeviceState(address) if network_id is not None: @@ -80,6 +85,9 @@ def observe_device( f"Device key for {address:016x} has changed" ) self.devices[address].device_id = device_id + if bt_addr is not None: + self.bt_addr[bt_addr] = address + self.devices[address].bt_addr = bt_addr def observe_security_state( self, address: int, cloud_key: bytes, device_key: bytes, network_id: int diff --git a/src/infuse_iot/epacket/interface.py b/src/infuse_iot/epacket/interface.py index 55d01b0..b76510f 100644 --- a/src/infuse_iot/epacket/interface.py +++ b/src/infuse_iot/epacket/interface.py @@ -65,6 +65,12 @@ def __str__(self) -> str: def len(self): return ctypes.sizeof(self.CtypesFormat) + def to_ctype(self) -> CtypesFormat: + """Convert the address to the ctype format""" + return self.CtypesFormat( + self.addr_type, bytes_to_uint8(self.addr_val.to_bytes(6, "little")) + ) + def to_json(self) -> Dict: return {"i": "BT", "t": self.addr_type, "v": self.addr_val} diff --git a/src/infuse_iot/epacket/packet.py b/src/infuse_iot/epacket/packet.py index 7436ab0..fc2932c 100644 --- a/src/infuse_iot/epacket/packet.py +++ b/src/infuse_iot/epacket/packet.py @@ -18,7 +18,7 @@ from infuse_iot.time import InfuseTime -class Auth(enum.Enum): +class Auth(enum.IntEnum): """Authorisation options""" DEVICE = 0 @@ -30,89 +30,6 @@ class Flags(enum.IntEnum): ENCR_NETWORK = 0x0000 -class InterfaceAddress(Serializable): - class SerialAddr(Serializable): - def __str__(self): - return "" - - def len(self): - return 0 - - def to_json(self) -> Dict: - return {"i": "SERIAL"} - - @classmethod - def from_json(cls, values: Dict) -> Self: - return cls() - - class BluetoothLeAddr(Serializable): - class CtypesFormat(ctypes.Structure): - _fields_ = [ - ("type", ctypes.c_uint8), - ("addr", 6 * ctypes.c_uint8), - ] - _pack_ = 1 - - def __init__(self, addr_type, addr_val): - self.addr_type = addr_type - self.addr_val = addr_val - - def __hash__(self) -> int: - return (self.addr_type << 48) + self.addr_val - - def __eq__(self, another) -> bool: - return ( - self.addr_type == another.addr_type - and self.addr_val == another.addr_val - ) - - def __str__(self) -> str: - t = "random" if self.addr_type == 1 else "public" - v = ":".join([f"{x:02x}" for x in self.addr_val.to_bytes(6, "big")]) - return f"{v} ({t})" - - def len(self): - return ctypes.sizeof(self.CtypesFormat) - - def to_json(self) -> Dict: - return {"i": "BT", "t": self.addr_type, "v": self.addr_val} - - @classmethod - def from_json(cls, values: Dict) -> Self: - return cls(values["t"], values["v"]) - - def __init__(self, val): - self.val = val - - def __str__(self): - return str(self.val) - - def len(self): - return self.val.len() - - def to_json(self) -> Dict: - return self.val.to_json() - - @classmethod - def from_json(cls, values: Dict) -> Self: - if values["i"] == "BT": - return cls(cls.BluetoothLeAddr.from_json(values)) - elif values["i"] == "SERIAL": - return cls(cls.SerialAddr()) - raise NotImplementedError("Unknown address type") - - @classmethod - def from_bytes(cls, interface: Interface, stream: bytes) -> Self: - assert interface in [ - Interface.BT_ADV, - Interface.BT_PERIPHERAL, - Interface.BT_CENTRAL, - ] - - c = cls.BluetoothLeAddr.CtypesFormat.from_buffer_copy(stream) - return cls.BluetoothLeAddr(c.type, int.from_bytes(bytes(c.addr), "little")) - - class HopOutput(Serializable): def __init__(self, infuse_id: int, interface: Interface, auth: Auth): self.infuse_id = infuse_id @@ -145,7 +62,7 @@ def __init__( self, infuse_id: int, interface: Interface, - interface_address: InterfaceAddress, + interface_address: Address, auth: Auth, key_identifier: int, gps_time: int, @@ -179,7 +96,7 @@ def from_json(cls, values: Dict) -> Self: return cls( infuse_id=values["id"], interface=interface, - interface_address=InterfaceAddress.from_json(values["interface_addr"]), + interface_address=Address.from_json(values["interface_addr"]), auth=Auth(values["auth"]), key_identifier=values["key_id"], gps_time=values["time"], @@ -242,7 +159,7 @@ def from_serial(cls, database: DeviceDatabase, serial_frame: bytes) -> List[Self frame_type = decode_mapping[common_header.interface] # Extract interface address (Only Bluetooth supported) - addr = InterfaceAddress.from_bytes(common_header.interface, packet_bytes) + addr = Address.from_bytes(common_header.interface, packet_bytes) del packet_bytes[: addr.len()] # Decrypting packet @@ -307,8 +224,8 @@ def from_serial(cls, database: DeviceDatabase, serial_frame: bytes) -> List[Self return packets -class PacketOutput(Serializable): - """ePacket to be transmitted by gateway""" +class PacketOutputRouted(Serializable): + """ePacket to be transmitted by gateway with complete route""" def __init__(self, route: List[HopOutput], ptype: InfuseType, payload: bytes): # [Serial, hop, hop, final_hop] @@ -319,23 +236,47 @@ def __init__(self, route: List[HopOutput], ptype: InfuseType, payload: bytes): def to_serial(self, database: DeviceDatabase) -> bytes: """Encode and encrypt packet for serial transmission""" gps_time = InfuseTime.gps_seconds_from_unix(int(time.time())) - # Multi hop not currently supported - assert len(self.route) == 1 - route = self.route[0] - if route.auth == Auth.NETWORK: + if len(self.route) == 2: + # Two hops only supports Bluetooth central for now + final = self.route[1] + assert final.interface == Interface.BT_CENTRAL + + # Forwarded payload + forward_payload = CtypeBtGattFrame.encrypt( + database, final.infuse_id, self.ptype, Auth.DEVICE, self.payload + ) + + # Forwarding header + forward_hdr = CtypeForwardHeaderBtGatt( + ctypes.sizeof(CtypeForwardHeaderBtGatt) + len(forward_payload), + Interface.BT_CENTRAL.value, + database.devices[final.infuse_id].bt_addr.to_ctype(), + ) + + ptype = InfuseType.EPACKET_FORWARD + payload = bytes(forward_hdr) + forward_payload + elif len(self.route) == 1: + ptype = self.ptype + payload = self.payload + else: + raise NotImplementedError(">2 hops currently not supported") + + serial = self.route[0] + + if serial.auth == Auth.NETWORK: flags = Flags.ENCR_NETWORK - key_metadata = database.devices[route.infuse_id].network_id - key = database.serial_network_key(route.infuse_id, gps_time) + key_metadata = database.devices[serial.infuse_id].network_id + key = database.serial_network_key(serial.infuse_id, gps_time) else: flags = Flags.ENCR_DEVICE - key_metadata = database.devices[route.infuse_id].device_id - key = database.serial_device_key(route.infuse_id, gps_time) + key_metadata = database.devices[serial.infuse_id].device_id + key = database.serial_device_key(serial.infuse_id, gps_time) # Create header header = CtypeSerialFrame( version=0, - _type=self.ptype.value, + _type=ptype, flags=flags, gps_time=gps_time, sequence=0, @@ -347,7 +288,7 @@ def to_serial(self, database: DeviceDatabase) -> bytes: # Encrypt and return payload header_bytes = bytes(header) ciphertext = chachapoly_encrypt( - key, header_bytes[:11], header_bytes[11:], self.payload + key, header_bytes[:11], header_bytes[11:], payload ) return header_bytes + ciphertext @@ -367,6 +308,42 @@ def from_json(cls, values: Dict) -> Self: ) +class PacketOutput(PacketOutputRouted): + """ePacket to be transmitted by gateway""" + + def __init__(self, infuse_id: int, auth: Auth, ptype: InfuseType, payload: bytes): + self.infuse_id = infuse_id + self.auth = auth + self.ptype = ptype + self.payload = payload + + def to_json(self) -> Dict: + return { + "infuse_id": self.infuse_id, + "auth": self.auth, + "type": self.ptype.value, + "payload": base64.b64encode(self.payload).decode("utf-8"), + } + + @classmethod + def from_json(cls, values: Dict) -> Self: + return cls( + infuse_id=values["infuse_id"], + auth=Auth(values["auth"]), + ptype=InfuseType(values["type"]), + payload=base64.b64decode(values["payload"].encode("utf-8")), + ) + + +class CtypeForwardHeaderBtGatt(ctypes.LittleEndianStructure): + _fields_ = [ + ("total_length", ctypes.c_uint16), + ("interface", ctypes.c_uint8), + ("address", Address.BluetoothLeAddr.CtypesFormat), + ] + _pack_ = 1 + + class CtypeV0VersionedFrame(ctypes.LittleEndianStructure): _fields_ = [ ("version", ctypes.c_uint8), @@ -419,7 +396,7 @@ def hop_received(self) -> HopReceived: return HopReceived( self.device_id, Interface.SERIAL, - InterfaceAddress(InterfaceAddress.SerialAddr()), + Address(Address.SerialAddr()), auth, self.key_metadata, self.gps_time, @@ -464,6 +441,44 @@ def decrypt( class CtypeBtGattFrame(CtypeV0VersionedFrame): """Bluetooth GATT packet header""" + @classmethod + def encrypt( + cls, + database: DeviceDatabase, + infuse_id: int, + ptype: InfuseType, + auth: Auth, + payload: bytes, + ) -> bytes: + dev_state = database.devices[infuse_id] + gps_time = InfuseTime.gps_seconds_from_unix(int(time.time())) + flags = 0 + + if auth == Auth.DEVICE: + key_meta = dev_state.device_id + key = database.bt_gatt_device_key(infuse_id, gps_time) + flags |= Flags.ENCR_DEVICE + else: + key_meta = dev_state.network_id + key = database.bt_gatt_network_key(infuse_id, gps_time) + + # Construct GATT header + header = cls() + header._type = ptype + header.flags = flags + header.device_id = infuse_id + header.key_metadata = key_meta + header.gps_time = gps_time + header.sequence = dev_state.gatt_sequence_num() + header.entropy = random.randint(0, 65535) + + # Encrypt and return payload + header_bytes = bytes(header) + ciphertext = chachapoly_encrypt( + key, header_bytes[:11], header_bytes[11:], payload + ) + return header_bytes + ciphertext + @classmethod def decrypt( cls, database: DeviceDatabase, bt_addr: Address.BluetoothLeAddr, frame: bytes diff --git a/src/infuse_iot/rpc_wrappers/time_get.py b/src/infuse_iot/rpc_wrappers/time_get.py index 7fe70b0..327250d 100644 --- a/src/infuse_iot/rpc_wrappers/time_get.py +++ b/src/infuse_iot/rpc_wrappers/time_get.py @@ -25,8 +25,13 @@ def handle_response(self, return_code, response): t_remote = InfuseTime.unix_time_from_epoch(response.epoch_time) t_local = time.time() + sync_age = ( + f"{response.sync_age} seconds ago" + if response.sync_age != 2**32 - 1 + else "Never" + ) print(f"\t Source: {InfuseTimeSource(response.time_source)}") print(f"\tRemote Time: {InfuseTime.utc_time_string(t_remote)}") print(f"\t Local Time: {InfuseTime.utc_time_string(t_local)}") - print(f"\t Synced: {response.sync_age} seconds ago") + print(f"\t Synced: {sync_age}") diff --git a/src/infuse_iot/socket_comms.py b/src/infuse_iot/socket_comms.py index cd65d26..d1ba448 100644 --- a/src/infuse_iot/socket_comms.py +++ b/src/infuse_iot/socket_comms.py @@ -3,6 +3,10 @@ import socket import struct import json +import enum + +from typing import Dict +from typing_extensions import Self from infuse_iot.epacket.packet import PacketReceived, PacketOutput @@ -11,6 +15,89 @@ def default_multicast_address(): return ("224.1.1.1", 8751) +class ClientNotification: + class Type(enum.IntEnum): + EPACKET_RECV = 0 + CONNECTION_FAILED = 1 + CONNECTION_CREATED = 2 + CONNECTION_DROPPED = 3 + + def __init__( + self, + notification_type: Type, + epacket: PacketReceived | None = None, + connection_id: int | None = None, + ): + self.type = notification_type + self.epacket = epacket + self.connection_id = connection_id + + def to_json(self) -> Dict: + """Convert class to json dictionary""" + out = {"type": int(self.type)} + if self.epacket: + out["epacket"] = self.epacket.to_json() + if self.connection_id: + out["connection_id"] = self.connection_id + return out + + @classmethod + def from_json(cls, values: Dict) -> Self: + """Reconstruct class from json dictionary""" + if j := values.get("epacket", None): + epacket = PacketReceived.from_json(j) + else: + epacket = None + connection_id = values.get("connection_id", None) + + return cls( + notification_type=cls.Type(values["type"]), + epacket=epacket, + connection_id=connection_id, + ) + + +class GatewayRequest: + class Type(enum.IntEnum): + EPACKET_SEND = 0 + CONNECTION_REQUEST = 1 + CONNECTION_RELEASE = 2 + + def __init__( + self, + notification_type: Type, + epacket: PacketOutput | None = None, + connection_id: int | None = None, + ): + self.type = notification_type + self.epacket = epacket + self.connection_id = connection_id + + def to_json(self) -> Dict: + """Convert class to json dictionary""" + out = {"type": int(self.type)} + if self.epacket: + out["epacket"] = self.epacket.to_json() + if self.connection_id: + out["connection_id"] = self.connection_id + return out + + @classmethod + def from_json(cls, values: Dict) -> Self: + """Reconstruct class from json dictionary""" + if j := values.get("epacket", None): + epacket = PacketOutput.from_json(j) + else: + epacket = None + connection_id = values.get("connection_id", None) + + return cls( + notification_type=cls.Type(values["type"]), + epacket=epacket, + connection_id=connection_id, + ) + + class LocalServer: def __init__(self, multicast_address): # Multicast output socket @@ -27,17 +114,17 @@ def __init__(self, multicast_address): self._input_sock.bind(unicast_address) self._input_sock.settimeout(0.2) - def broadcast(self, packet: PacketReceived): + def broadcast(self, notification: ClientNotification): self._output_sock.sendto( - json.dumps(packet.to_json()).encode("utf-8"), self._output_addr + json.dumps(notification.to_json()).encode("utf-8"), self._output_addr ) - def receive(self) -> PacketOutput | None: + def receive(self) -> GatewayRequest | None: try: data, _ = self._input_sock.recvfrom(8192) except TimeoutError: return None - return PacketOutput.from_json(json.loads(data.decode("utf-8"))) + return GatewayRequest.from_json(json.loads(data.decode("utf-8"))) def close(self): self._input_sock.close() @@ -62,21 +149,56 @@ def __init__(self, multicast_address, rx_timeout=0.2): self._output_sock = socket.socket( socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP ) + # Connection context + self._connection_id = None def set_rx_timeout(self, timeout): self._input_sock.settimeout(timeout) - def send(self, packet: PacketOutput): + def send(self, request: GatewayRequest): self._output_sock.sendto( - json.dumps(packet.to_json()).encode("utf-8"), self._output_addr + json.dumps(request.to_json()).encode("utf-8"), self._output_addr ) - def receive(self) -> PacketReceived | None: + def receive(self) -> ClientNotification | None: try: data, _ = self._input_sock.recvfrom(8192) except TimeoutError: return None - return PacketReceived.from_json(json.loads(data.decode("utf-8"))) + return ClientNotification.from_json(json.loads(data.decode("utf-8"))) + + def connection_create(self, infuse_id: int): + self._connection_id = infuse_id + + # Send the request for the connection + req = GatewayRequest( + GatewayRequest.Type.CONNECTION_REQUEST, connection_id=infuse_id + ) + self.send(req) + # Wait for response from the server + while rsp := self.receive(): + if rsp.connection_id == infuse_id: + if rsp.type == ClientNotification.Type.CONNECTION_CREATED: + break + elif rsp.type == ClientNotification.Type.CONNECTION_FAILED: + raise ConnectionRefusedError + raise NotImplementedError("Unexpected response") + + def connection_release(self): + req = GatewayRequest( + GatewayRequest.Type.CONNECTION_RELEASE, + connection_id=self._connection_id, + ) + self.send(req) + self._connection_id = None def close(self): + # Cleanup any lingering connection context + if self._connection_id: + req = GatewayRequest( + GatewayRequest.Type.CONNECTION_RELEASE, + connection_id=self._connection_id, + ) + self.send(req) + # Close the socket self._input_sock.close() diff --git a/src/infuse_iot/tools/gateway.py b/src/infuse_iot/tools/gateway.py index 1b87332..6001a3f 100644 --- a/src/infuse_iot/tools/gateway.py +++ b/src/infuse_iot/tools/gateway.py @@ -18,10 +18,15 @@ from infuse_iot.util.argparse import ValidFile from infuse_iot.util.console import Console -from infuse_iot.common import InfuseType +from infuse_iot.common import InfuseType, InfuseID from infuse_iot.commands import InfuseCommand from infuse_iot.serial_comms import RttPort, SerialPort, SerialFrame -from infuse_iot.socket_comms import LocalServer, default_multicast_address +from infuse_iot.socket_comms import ( + LocalServer, + ClientNotification, + GatewayRequest, + default_multicast_address, +) from infuse_iot.database import ( DeviceDatabase, NoKeyError, @@ -30,7 +35,7 @@ from infuse_iot.epacket.packet import ( Auth, PacketReceived, - PacketOutput, + PacketOutputRouted, HopOutput, ) import infuse_iot.epacket.interface as interface @@ -47,11 +52,11 @@ def __init__(self, database: DeviceDatabase): self._ddb = database self._queued = {} - def generate(self, command: int, args: bytes, cb): + def generate(self, command: int, args: bytes, auth: Auth, cb): """Generate RPC packet from arguments""" cmd_bytes = bytes(rpc.RequestHeader(self._cnt, command)) + args - cmd_pkt = PacketOutput( - [HopOutput.serial(Auth.NETWORK)], + cmd_pkt = PacketOutputRouted( + [HopOutput.serial(auth)], InfuseType.RPC_CMD, cmd_bytes, ) @@ -122,7 +127,9 @@ def security_state_done(pkt: PacketReceived, _: int, response: bytes): cb_event.set() # Generate security_state RPC - cmd_pkt = self.rpc.generate(30000, random.randbytes(16), security_state_done) + cmd_pkt = self.rpc.generate( + 30000, random.randbytes(16), Auth.NETWORK, security_state_done + ) encrypted = cmd_pkt.to_serial(self.ddb) # Write to serial port Console.log_tx(cmd_pkt.ptype, len(encrypted)) @@ -231,8 +238,12 @@ def _handle_serial_frame(self, frame: bytearray): # Proactively requery keys elif pkt.ptype == InfuseType.KEY_IDS: self._common.query_device_key(None) + + notification = ClientNotification( + ClientNotification.Type.EPACKET_RECV, epacket=pkt + ) # Forward to clients - self._common.server.broadcast(pkt) + self._common.server.broadcast(notification) except (ValueError, KeyError) as e: print(f"Decode failed ({e})") @@ -252,35 +263,133 @@ def send(self, pkt): """Queue packet for transmission""" self._queue.put(pkt) + def _handle_epacket_send(self, req: GatewayRequest): + if self._common.ddb.gateway is None: + Console.log_error("Gateway address unknown") + return + + pkt = req.epacket + + # Construct routed output + if pkt.infuse_id == InfuseID.GATEWAY: + routed = PacketOutputRouted( + [HopOutput(self._common.ddb.gateway, interface.ID.SERIAL, pkt.auth)], + pkt.ptype, + pkt.payload, + ) + else: + gateway = self._common.ddb.gateway + serial = HopOutput(gateway, interface.ID.SERIAL, Auth.DEVICE) + bt = HopOutput(pkt.infuse_id, interface.ID.BT_CENTRAL, pkt.auth) + routed = PacketOutputRouted( + [serial, bt], + pkt.ptype, + pkt.payload, + ) + + # Do we have the device public keys we need? + for hop in routed.route: + if hop.auth == Auth.DEVICE and not self._common.ddb.has_public_key( + hop.infuse_id + ): + cb_event = threading.Event() + self._common.query_device_key(cb_event) + + # Encode and encrypt payload + encrypted = routed.to_serial(self._common.ddb) + + # Write to serial port + Console.log_tx(routed.ptype, len(encrypted)) + self._common.port.write(encrypted) + + def _bt_connect_cb(self, pkt: PacketReceived, rc: int, response: bytes): + resp = defs.bt_connect_infuse.response.from_buffer_copy( + pkt.payload[ctypes.sizeof(rpc.ResponseHeader) :] + ) + if_addr = interface.Address.BluetoothLeAddr.from_rpc_struct(resp.peer) + infuse_id = self._common.ddb.infuse_id_from_bluetooth(if_addr) + + evt = ( + ClientNotification.Type.CONNECTION_FAILED + if rc < 0 + else ClientNotification.Type.CONNECTION_CREATED + ) + rsp = ClientNotification( + evt, + connection_id=infuse_id, + ) + self._common.server.broadcast(rsp) + + def _handle_conn_request(self, req: GatewayRequest): + if req.connection_id == InfuseID.GATEWAY: + # Local gateway always connected + rsp = ClientNotification( + ClientNotification.Type.CONNECTION_CREATED, + connection_id=req.connection_id, + ) + self._common.server.broadcast(rsp) + return + + state = self._common.ddb.devices.get(req.connection_id, None) + if state is None or state.bt_addr is None: + rsp = ClientNotification( + ClientNotification.Type.CONNECTION_FAILED, + connection_id=req.connection_id, + ) + self._common.server.broadcast(rsp) + return + + device_info = self._common.ddb.devices[req.connection_id] + + connect_args = defs.bt_connect_infuse.request( + device_info.bt_addr.to_rpc_struct(), + 10000, + defs.rpc_enum_infuse_bt_characteristic.COMMAND, + 0, + ) + cmd = self._common.rpc.generate( + defs.bt_connect_infuse.COMMAND_ID, + bytes(connect_args), + Auth.DEVICE, + self._bt_connect_cb, + ) + encrypted = cmd.to_serial(self._common.ddb) + Console.log_tx(cmd.ptype, len(encrypted)) + self._common.port.write(encrypted) + + def _handle_conn_release(self, req: GatewayRequest): + if req.connection_id == InfuseID.GATEWAY: + # Local gateway always connected + return + + state = self._common.ddb.devices.get(req.connection_id, None) + if state is None or state.bt_addr is None: + # Unknown device, nothing to do + return + + disconnect_args = defs.bt_disconnect.request(state.bt_addr.to_rpc_struct()) + cmd = self._common.rpc.generate( + defs.bt_disconnect.COMMAND_ID, bytes(disconnect_args) + ) + encrypted = cmd.to_serial(self._common.ddb) + Console.log_tx(cmd.ptype, len(encrypted)) + self._common.port.write(encrypted) + def _iter(self): if self._common.server is None: time.sleep(1.0) return # Loop while there are packets to send - while pkt := self._common.server.receive(): - if self._common.ddb.gateway is None: - Console.log_error("Gateway address unknown") - continue - - # Set gateway address - assert pkt.route[0].interface == interface.ID.SERIAL - pkt.route[0].infuse_id = self._common.ddb.gateway - - # Do we have the device public keys we need? - for hop in pkt.route: - if hop.auth == Auth.DEVICE and not self._common.ddb.has_public_key( - hop.infuse_id - ): - cb_event = threading.Event() - self._common.query_device_key(cb_event) - - # Encode and encrypt payload - encrypted = pkt.to_serial(self._common.ddb) - - # Write to serial port - Console.log_tx(pkt.ptype, len(encrypted)) - self._common.port.write(encrypted) + while req := self._common.server.receive(): + if req.type == GatewayRequest.Type.EPACKET_SEND: + self._handle_epacket_send(req) + elif req.type == GatewayRequest.Type.CONNECTION_REQUEST: + self._handle_conn_request(req) + elif req.type == GatewayRequest.Type.CONNECTION_RELEASE: + self._handle_conn_release(req) + else: + Console.log_error(f"Unhandled request {req.type}") class SubCommand(InfuseCommand): diff --git a/src/infuse_iot/tools/localhost.py b/src/infuse_iot/tools/localhost.py index 789e90c..b3d4488 100644 --- a/src/infuse_iot/tools/localhost.py +++ b/src/infuse_iot/tools/localhost.py @@ -16,7 +16,11 @@ from infuse_iot.util.console import Console from infuse_iot.common import InfuseType from infuse_iot.commands import InfuseCommand -from infuse_iot.socket_comms import LocalClient, default_multicast_address +from infuse_iot.socket_comms import ( + LocalClient, + ClientNotification, + default_multicast_address, +) from infuse_iot.tdf import TDF from infuse_iot.time import InfuseTime import infuse_iot.epacket.interface as interface @@ -154,10 +158,12 @@ def recv_thread(self): break if msg is None: continue - if msg.ptype != InfuseType.TDF: + if msg.type != ClientNotification.Type.EPACKET_RECV: + continue + if msg.epacket.ptype != InfuseType.TDF: continue - source = msg.route[0] + source = msg.epacket.route[0] self._data_lock.acquire(blocking=True) @@ -174,7 +180,7 @@ def recv_thread(self): self._data[source.infuse_id]["bt_addr"] = addr_str self._data[source.infuse_id]["bt_rssi"] = source.rssi - for tdf in self._decoder.decode(msg.payload): + for tdf in self._decoder.decode(msg.epacket.payload): t = tdf.data[-1] if t.name not in self._columns: self._columns[t.name] = self.tdf_columns(t) diff --git a/src/infuse_iot/tools/native_bt.py b/src/infuse_iot/tools/native_bt.py index c748824..2af0a4a 100644 --- a/src/infuse_iot/tools/native_bt.py +++ b/src/infuse_iot/tools/native_bt.py @@ -21,7 +21,11 @@ Flags, ) from infuse_iot.commands import InfuseCommand -from infuse_iot.socket_comms import LocalServer, default_multicast_address +from infuse_iot.socket_comms import ( + LocalServer, + ClientNotification, + default_multicast_address, +) from infuse_iot.database import DeviceDatabase import infuse_iot.epacket.interface as interface @@ -62,7 +66,8 @@ def simple_callback(self, device: BLEDevice, data: AdvertisementData): Console.log_rx(hdr.type, len(payload)) pkt = PacketReceived([hop], hdr.type, decr) - self.server.broadcast(pkt) + notification = ClientNotification(ClientNotification.Type.EPACKET_RECV, pkt) + self.server.broadcast(notification) async def async_run(self): self.server = LocalServer(default_multicast_address()) diff --git a/src/infuse_iot/tools/rpc.py b/src/infuse_iot/tools/rpc.py index 04585eb..c5cf386 100644 --- a/src/infuse_iot/tools/rpc.py +++ b/src/infuse_iot/tools/rpc.py @@ -11,10 +11,15 @@ import importlib import pkgutil -from infuse_iot.common import InfuseType -from infuse_iot.epacket.packet import PacketOutput, HopOutput +from infuse_iot.common import InfuseType, InfuseID +from infuse_iot.epacket.packet import PacketOutput from infuse_iot.commands import InfuseCommand, InfuseRpcCommand -from infuse_iot.socket_comms import LocalClient, default_multicast_address +from infuse_iot.socket_comms import ( + LocalClient, + ClientNotification, + GatewayRequest, + default_multicast_address, +) from infuse_iot import rpc import infuse_iot.rpc_wrappers as wrappers @@ -27,6 +32,13 @@ class SubCommand(InfuseCommand): @classmethod def add_parser(cls, parser): + addr_group = parser.add_mutually_exclusive_group(required=True) + addr_group.add_argument( + "--gateway", action="store_true", help="Run command on local gateway" + ) + addr_group.add_argument( + "--id", type=lambda x: int(x, 0), help="Infuse ID to run command on" + ) command_list_parser = parser.add_subparsers( title="commands", metavar="", required=True ) @@ -51,12 +63,18 @@ def __init__(self, args): self._client = LocalClient(default_multicast_address(), 10.0) self._command: InfuseRpcCommand = args.rpc_class(args) self._request_id = random.randint(0, 2**32 - 1) + if args.gateway: + self._id = InfuseID.GATEWAY + else: + self._id = args.id def _wait_data_ack(self): while rsp := self._client.receive(): - if rsp.ptype != InfuseType.RPC_DATA_ACK: + if rsp.type != ClientNotification.Type.EPACKET_RECV: continue - data_ack = rpc.DataAck.from_buffer_copy(rsp.payload) + if rsp.epacket.ptype != InfuseType.RPC_DATA_ACK: + continue + data_ack = rpc.DataAck.from_buffer_copy(rsp.epacket.payload) # Response to the request we sent if data_ack.request_id != self._request_id: continue @@ -65,19 +83,21 @@ def _wait_data_ack(self): def _wait_rpc_rsp(self): # Wait for responses while rsp := self._client.receive(): + if rsp.type != ClientNotification.Type.EPACKET_RECV: + continue # RPC response packet - if rsp.ptype != InfuseType.RPC_RSP: + if rsp.epacket.ptype != InfuseType.RPC_RSP: continue - rsp_header = rpc.ResponseHeader.from_buffer_copy(rsp.payload) + rsp_header = rpc.ResponseHeader.from_buffer_copy(rsp.epacket.payload) # Response to the request we sent if rsp_header.request_id != self._request_id: continue # Convert response bytes back to struct form rsp_data = self._command.response.from_buffer_copy( - rsp.payload[ctypes.sizeof(rpc.ResponseHeader) :] + rsp.epacket.payload[ctypes.sizeof(rpc.ResponseHeader) :] ) # Handle the response - print(f"INFUSE ID: {rsp.route[0].infuse_id:016x}") + print(f"INFUSE ID: {rsp.epacket.route[0].infuse_id:016x}") self._command.handle_response(rsp_header.return_code, rsp_data) break @@ -90,11 +110,13 @@ def _run_data_cmd(self): request_packet = bytes(header) + bytes(data_hdr) + bytes(params) pkt = PacketOutput( - [HopOutput.serial(self._command.auth_level())], + self._id, + self._command.auth_level(), InfuseType.RPC_CMD, request_packet, ) - self._client.send(pkt) + req = GatewayRequest(GatewayRequest.Type.EPACKET_SEND, epacket=pkt) + self._client.send(req) # Wait for initial ACK self._wait_data_ack() @@ -110,7 +132,8 @@ def _run_data_cmd(self): hdr = rpc.DataHeader(self._request_id, offset) pkt_bytes = bytes(hdr) + payload pkt = PacketOutput( - [HopOutput.serial(self._command.auth_level())], + self._id, + self._command.auth_level(), InfuseType.RPC_DATA, pkt_bytes, ) @@ -134,15 +157,21 @@ def _run_standard_cmd(self): request_packet = bytes(header) + bytes(params) pkt = PacketOutput( - [HopOutput.serial(self._command.auth_level())], + self._id, + self._command.auth_level(), InfuseType.RPC_CMD, request_packet, ) - self._client.send(pkt) + req = GatewayRequest(GatewayRequest.Type.EPACKET_SEND, epacket=pkt) + self._client.send(req) self._wait_rpc_rsp() def run(self): - if self._command.RPC_DATA: - self._run_data_cmd() - else: - self._run_standard_cmd() + try: + self._client.connection_create(self._id) + if self._command.RPC_DATA: + self._run_data_cmd() + else: + self._run_standard_cmd() + finally: + self._client.connection_release() diff --git a/src/infuse_iot/tools/serial_throughput.py b/src/infuse_iot/tools/serial_throughput.py index 05f3f27..1ced257 100644 --- a/src/infuse_iot/tools/serial_throughput.py +++ b/src/infuse_iot/tools/serial_throughput.py @@ -8,10 +8,15 @@ import random import time -from infuse_iot.common import InfuseType -from infuse_iot.epacket.packet import PacketOutput, HopOutput +from infuse_iot.common import InfuseType, InfuseID +from infuse_iot.epacket.packet import PacketOutput, Auth from infuse_iot.commands import InfuseCommand -from infuse_iot.socket_comms import LocalClient, default_multicast_address +from infuse_iot.socket_comms import ( + LocalClient, + ClientNotification, + GatewayRequest, + default_multicast_address, +) class SubCommand(InfuseCommand): @@ -45,16 +50,20 @@ def run_send_test(self, num, size, queue_size): while (sent != num) and (pending < queue_size): payload = sent.to_bytes(4, "little") + random.randbytes(size - 4) pkt = PacketOutput( - [HopOutput.serial()], + InfuseID.GATEWAY, + Auth.DEVICE, InfuseType.ECHO_REQ, payload, ) - self._client.send(pkt) + req = GatewayRequest(GatewayRequest.Type.EPACKET_SEND, epacket=pkt) + self._client.send(req) sent += 1 pending += 1 # Wait for responses if rsp := self._client.receive(): - if rsp.ptype != InfuseType.ECHO_RSP: + if rsp.type != ClientNotification.Type.EPACKET_RECV: + continue + if rsp.epacket.ptype != InfuseType.ECHO_RSP: continue responses += 1 pending -= 1 diff --git a/src/infuse_iot/tools/tdf_csv.py b/src/infuse_iot/tools/tdf_csv.py index 1104e7c..d8181f4 100644 --- a/src/infuse_iot/tools/tdf_csv.py +++ b/src/infuse_iot/tools/tdf_csv.py @@ -10,7 +10,11 @@ from infuse_iot.common import InfuseType from infuse_iot.commands import InfuseCommand -from infuse_iot.socket_comms import LocalClient, default_multicast_address +from infuse_iot.socket_comms import ( + LocalClient, + ClientNotification, + default_multicast_address, +) from infuse_iot.tdf import TDF from infuse_iot.time import InfuseTime @@ -38,11 +42,13 @@ def run(self): msg = self._client.receive() if msg is None: continue - if msg.ptype != InfuseType.TDF: + if msg.type != ClientNotification.Type.EPACKET_RECV: continue - source = msg.route[0] + if msg.epacket.ptype != InfuseType.TDF: + continue + source = msg.epacket.route[0] - for tdf in self._decoder.decode(msg.payload): + for tdf in self._decoder.decode(msg.epacket.payload): # Construct reading strings lines = [] reading_time = tdf.time diff --git a/src/infuse_iot/tools/tdf_list.py b/src/infuse_iot/tools/tdf_list.py index 9bf92f4..28f9b1c 100644 --- a/src/infuse_iot/tools/tdf_list.py +++ b/src/infuse_iot/tools/tdf_list.py @@ -9,7 +9,11 @@ from infuse_iot.common import InfuseType from infuse_iot.commands import InfuseCommand -from infuse_iot.socket_comms import LocalClient, default_multicast_address +from infuse_iot.socket_comms import ( + LocalClient, + ClientNotification, + default_multicast_address, +) from infuse_iot.tdf import TDF from infuse_iot.time import InfuseTime @@ -28,13 +32,15 @@ def run(self): msg = self._client.receive() if msg is None: continue - if msg.ptype != InfuseType.TDF: + if msg.type != ClientNotification.Type.EPACKET_RECV: continue - source = msg.route[0] + if msg.epacket.ptype != InfuseType.TDF: + continue + source = msg.epacket.route[0] table = [] - for tdf in self._decoder.decode(msg.payload): + for tdf in self._decoder.decode(msg.epacket.payload): t = tdf.data[-1] num = len(tdf.data) if num > 1: