diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..ca7c5a7 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,25 @@ +[mypy] + +[mypy-colorama] +ignore_missing_imports = True + +[mypy-serial] +ignore_missing_imports = True + +[mypy-pylink] +ignore_missing_imports = True + +[mypy-pynrfjprog] +ignore_missing_imports = True + +[mypy-simple_term_menu] +ignore_missing_imports = True + +[mypy-dateutil.*] +ignore_missing_imports = True + +[mypy-plotly.*] +ignore_missing_imports = True + +[mypy-dash.*] +ignore_missing_imports = True diff --git a/src/infuse_iot/commands.py b/src/infuse_iot/commands.py index ebc07ff..7828460 100644 --- a/src/infuse_iot/commands.py +++ b/src/infuse_iot/commands.py @@ -8,17 +8,24 @@ import argparse import ctypes +from typing import List, Type, Tuple + + from infuse_iot.epacket.packet import Auth class InfuseCommand: """Infuse-IoT SDK meta-tool command parent class""" + NAME = "N/A" + HELP = "N/A" + DESCRIPTION = "N/A" + @classmethod def add_parser(cls, parser: argparse.ArgumentParser): """Add arguments for sub-command""" - def __init__(self, **kwargs): + def __init__(self, args: argparse.Namespace): pass def run(self): @@ -56,9 +63,9 @@ def handle_response(self, return_code, response): raise NotImplementedError class VariableSizeResponse: - base_fields = [] + base_fields: List[Tuple[str, Type[ctypes._SimpleCData]]] = [] var_name = "x" - var_type = ctypes.c_ubyte + var_type: Type[ctypes._SimpleCData] = ctypes.c_ubyte @classmethod def from_buffer_copy(cls, source, offset=0): diff --git a/src/infuse_iot/database.py b/src/infuse_iot/database.py index 18bc4d9..d3670f2 100644 --- a/src/infuse_iot/database.py +++ b/src/infuse_iot/database.py @@ -2,7 +2,7 @@ import binascii import base64 -from typing import Dict +from typing import Dict, Tuple from infuse_iot.api_client import Client from infuse_iot.api_client.api.default import get_shared_secret @@ -38,18 +38,23 @@ class DeviceDatabase: _network_keys = { 0x000000: b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f", } - _derived_keys = {} + _derived_keys: Dict[Tuple[int, bytes, int], bytes] = {} class DeviceState: """Device State""" - def __init__(self, address, network_id=None, device_id=None): + def __init__( + self, + address: int, + network_id: int | None = None, + device_id: int | None = None, + ): self.address = address self.network_id = network_id self.device_id = device_id self.bt_addr: InterfaceAddress.BluetoothLeAddr | None = None - self.public_key = None - self.shared_key = None + self.public_key: bytes | None = None + self.shared_key: bytes | None = None self._tx_gatt_seq = 0 def gatt_sequence_num(self): @@ -57,8 +62,8 @@ def gatt_sequence_num(self): self._tx_gatt_seq += 1 return self._tx_gatt_seq - def __init__(self): - self.gateway = None + def __init__(self) -> None: + self.gateway: int | None = None self.devices: Dict[int, DeviceDatabase.DeviceState] = {} self.bt_addr: Dict[InterfaceAddress.BluetoothLeAddr, int] = {} @@ -68,7 +73,7 @@ def observe_device( network_id: int | None = None, device_id: int | None = None, bt_addr: InterfaceAddress.BluetoothLeAddr | None = None, - ): + ) -> None: """Update device state based on observed packet""" if self.gateway is None: self.gateway = address @@ -91,7 +96,7 @@ def observe_device( def observe_security_state( self, address: int, cloud_key: bytes, device_key: bytes, network_id: int - ): + ) -> None: """Update device state based on security_state response""" if address not in self.devices: self.devices[address] = self.DeviceState(address) @@ -107,10 +112,11 @@ def observe_security_state( with client as client: body = Key(base64.b64encode(device_key).decode("utf-8")) response = get_shared_secret.sync(client=client, body=body) - key = base64.b64decode(response.key) - self.devices[address].shared_key = key + if response is not None: + key = base64.b64decode(response.key) + self.devices[address].shared_key = key - def _network_key(self, network_id: int, interface: str, gps_time: int): + def _network_key(self, network_id: int, interface: bytes, gps_time: int) -> bytes: if network_id not in self._network_keys: try: info = load_network(network_id) @@ -128,22 +134,22 @@ def _network_key(self, network_id: int, interface: str, gps_time: int): return self._derived_keys[key_id] - def _serial_key(self, base, time_idx): + def _serial_key(self, base: bytes, time_idx: int) -> bytes: return hkdf_derive(base, time_idx.to_bytes(4, "little"), b"serial") - def _bt_adv_key(self, base, time_idx): + def _bt_adv_key(self, base: bytes, time_idx: int) -> bytes: return hkdf_derive(base, time_idx.to_bytes(4, "little"), b"bt_adv") - def _bt_gatt_key(self, base, time_idx): + def _bt_gatt_key(self, base: bytes, time_idx: int) -> bytes: return hkdf_derive(base, time_idx.to_bytes(4, "little"), b"bt_gatt") - def has_public_key(self, address: int): + def has_public_key(self, address: int) -> bool: """Does the database have the public key for this device?""" if address not in self.devices: return False return self.devices[address].public_key is not None - def has_network_id(self, address: int): + def has_network_id(self, address: int) -> bool: """Does the database know the network ID for this device?""" if address not in self.devices: return False @@ -155,15 +161,17 @@ def infuse_id_from_bluetooth( """Get Bluetooth address associated with device""" return self.bt_addr.get(bt_addr, None) - def serial_network_key(self, address: int, gps_time: int): + def serial_network_key(self, address: int, gps_time: int) -> bytes: """Network key for serial interface""" if address not in self.devices: raise DeviceUnknownNetworkKey network_id = self.devices[address].network_id + if network_id is None: + raise DeviceUnknownNetworkKey return self._network_key(network_id, b"serial", gps_time) - def serial_device_key(self, address: int, gps_time: int): + def serial_device_key(self, address: int, gps_time: int) -> bytes: """Device key for serial interface""" if address not in self.devices: raise DeviceUnknownDeviceKey @@ -177,15 +185,17 @@ def serial_device_key(self, address: int, gps_time: int): return self._serial_key(base, time_idx) - def bt_adv_network_key(self, address: int, gps_time: int): + def bt_adv_network_key(self, address: int, gps_time: int) -> bytes: """Network key for Bluetooth advertising interface""" if address not in self.devices: raise DeviceUnknownNetworkKey network_id = self.devices[address].network_id + if network_id is None: + raise DeviceUnknownNetworkKey return self._network_key(network_id, b"bt_adv", gps_time) - def bt_adv_device_key(self, address: int, gps_time: int): + def bt_adv_device_key(self, address: int, gps_time: int) -> bytes: """Device key for Bluetooth advertising interface""" if address not in self.devices: raise DeviceUnknownDeviceKey @@ -199,15 +209,17 @@ def bt_adv_device_key(self, address: int, gps_time: int): return self._bt_adv_key(base, time_idx) - def bt_gatt_network_key(self, address: int, gps_time: int): + def bt_gatt_network_key(self, address: int, gps_time: int) -> bytes: """Network key for Bluetooth advertising interface""" if address not in self.devices: raise DeviceUnknownNetworkKey network_id = self.devices[address].network_id + if network_id is None: + raise DeviceUnknownNetworkKey return self._network_key(network_id, b"bt_gatt", gps_time) - def bt_gatt_device_key(self, address: int, gps_time: int): + def bt_gatt_device_key(self, address: int, gps_time: int) -> bytes: """Device key for Bluetooth advertising interface""" if address not in self.devices: raise DeviceUnknownDeviceKey diff --git a/src/infuse_iot/diff.py b/src/infuse_iot/diff.py index 815c9ad..29898e4 100644 --- a/src/infuse_iot/diff.py +++ b/src/infuse_iot/diff.py @@ -6,7 +6,8 @@ import binascii from collections import defaultdict -from typing import List, Dict +from typing import List, Dict, Tuple, Type +from typing_extensions import Self class ValidationError(Exception): @@ -131,20 +132,25 @@ def ctypes_class(self): return self.SetAddrU32 @classmethod - def from_bytes(cls, b: bytes, offset: int, original_offset: int): + def from_bytes( + cls, b: bytes, offset: int, original_offset: int + ) -> Tuple[Self, int, int]: opcode = b[offset] if opcode == OpCode.ADDR_SHIFT_S8: - s = cls.ShiftAddrS8.from_buffer_copy(b, offset) - c = cls(original_offset, original_offset + s.val) + s8 = cls.ShiftAddrS8.from_buffer_copy(b, offset) + c = cls(original_offset, original_offset + s8.val) + struct_len = ctypes.sizeof(s8) elif opcode == OpCode.ADDR_SHIFT_S16: - s = cls.ShiftAddrS16.from_buffer_copy(b, offset) - c = cls(original_offset, original_offset + s.val) + s16 = cls.ShiftAddrS16.from_buffer_copy(b, offset) + c = cls(original_offset, original_offset + s16.val) + struct_len = ctypes.sizeof(s16) elif opcode == OpCode.ADDR_SET_U32: - s = cls.SetAddrU32.from_buffer_copy(b, offset) - c = cls(original_offset, s.val) + s32 = cls.SetAddrU32.from_buffer_copy(b, offset) + c = cls(original_offset, s32.val) + struct_len = ctypes.sizeof(s32) else: raise RuntimeError - return c, ctypes.sizeof(s), c.new + return c, struct_len, c.new def __bytes__(self): instr = self.ctypes_class() @@ -167,7 +173,10 @@ def __str__(self): class CopyInstr(Instr): - class CopyU4(ctypes.LittleEndianStructure): + class CopyGeneric(ctypes.LittleEndianStructure): + pass + + class CopyU4(CopyGeneric): op = OpCode.COPY_LEN_U4 _fields_ = [ ("opcode", ctypes.c_uint8), @@ -178,7 +187,7 @@ class CopyU4(ctypes.LittleEndianStructure): def length(self): return OpCode.data(self.opcode) - class CopyU12(ctypes.LittleEndianStructure): + class CopyU12(CopyGeneric): op = OpCode.COPY_LEN_U12 _fields_ = [ ("opcode", ctypes.c_uint8), @@ -190,7 +199,7 @@ class CopyU12(ctypes.LittleEndianStructure): def length(self): return (OpCode.data(self.opcode) << 8) | self._length - class CopyU20(ctypes.LittleEndianStructure): + class CopyU20(CopyGeneric): op = OpCode.COPY_LEN_U20 _fields_ = [ ("opcode", ctypes.c_uint8), @@ -202,7 +211,7 @@ class CopyU20(ctypes.LittleEndianStructure): def length(self): return (OpCode.data(self.opcode) << 16) | self._length - class CopyU32(ctypes.LittleEndianStructure): + class CopyU32(CopyGeneric): op = OpCode.COPY_LEN_U32 _fields_ = [ ("opcode", ctypes.c_uint8), @@ -230,19 +239,25 @@ def ctypes_class(self): return self.CopyU32 @classmethod - def from_bytes(cls, b: bytes, offset: int, original_offset: int): - opcode = OpCode.from_byte(b[offset]) - if opcode == OpCode.COPY_LEN_U4: - s = cls.CopyU4.from_buffer_copy(b, offset) - elif opcode == OpCode.COPY_LEN_U12: - s = cls.CopyU12.from_buffer_copy(b, offset) - elif opcode == OpCode.COPY_LEN_U20: - s = cls.CopyU20.from_buffer_copy(b, offset) - elif opcode == OpCode.COPY_LEN_U32: - s = cls.CopyU32.from_buffer_copy(b, offset) + def _from_opcode(cls, op: OpCode) -> Type[CopyGeneric]: + if op == OpCode.COPY_LEN_U4: + return cls.CopyU4 + elif op == OpCode.COPY_LEN_U12: + return cls.CopyU12 + elif op == OpCode.COPY_LEN_U20: + return cls.CopyU20 + elif op == OpCode.COPY_LEN_U32: + return cls.CopyU32 else: raise RuntimeError + @classmethod + def from_bytes( + cls, b: bytes, offset: int, original_offset: int + ) -> Tuple[Self, int, int]: + opcode = OpCode.from_byte(b[offset]) + op_class = cls._from_opcode(opcode) + s = op_class.from_buffer_copy(b, offset) return cls(s.length), ctypes.sizeof(s), original_offset + s.length def __bytes__(self): @@ -265,7 +280,10 @@ def __str__(self): class WriteInstr(Instr): - class WriteU4(ctypes.LittleEndianStructure): + class WriteGeneric(ctypes.LittleEndianStructure): + pass + + class WriteU4(WriteGeneric): op = OpCode.WRITE_LEN_U4 _fields_ = [ ("opcode", ctypes.c_uint8), @@ -276,7 +294,7 @@ class WriteU4(ctypes.LittleEndianStructure): def length(self): return OpCode.data(self.opcode) - class WriteU12(ctypes.LittleEndianStructure): + class WriteU12(WriteGeneric): op = OpCode.WRITE_LEN_U12 _fields_ = [ ("opcode", ctypes.c_uint8), @@ -288,7 +306,7 @@ class WriteU12(ctypes.LittleEndianStructure): def length(self): return (OpCode.data(self.opcode) << 8) | self._length - class WriteU20(ctypes.LittleEndianStructure): + class WriteU20(WriteGeneric): op = OpCode.WRITE_LEN_U20 _fields_ = [ ("opcode", ctypes.c_uint8), @@ -300,7 +318,7 @@ class WriteU20(ctypes.LittleEndianStructure): def length(self): return (OpCode.data(self.opcode) << 16) | self._length - class WriteU32(ctypes.LittleEndianStructure): + class WriteU32(WriteGeneric): op = OpCode.WRITE_LEN_U32 _fields_ = [ ("opcode", ctypes.c_uint8), @@ -326,18 +344,25 @@ def ctypes_class(self): return self.WriteU32 @classmethod - def from_bytes(cls, b: bytes, offset: int, original_offset: int): - opcode = OpCode.from_byte(b[offset]) - if opcode == OpCode.WRITE_LEN_U4: - s = cls.WriteU4.from_buffer_copy(b, offset) - elif opcode == OpCode.WRITE_LEN_U12: - s = cls.WriteU12.from_buffer_copy(b, offset) - elif opcode == OpCode.WRITE_LEN_U20: - s = cls.WriteU20.from_buffer_copy(b, offset) - elif opcode == OpCode.WRITE_LEN_U32: - s = cls.WriteU32.from_buffer_copy(b, offset) + def _from_opcode(cls, op: OpCode) -> Type[WriteGeneric]: + if op == OpCode.WRITE_LEN_U4: + return cls.WriteU4 + elif op == OpCode.WRITE_LEN_U12: + return cls.WriteU12 + elif op == OpCode.WRITE_LEN_U20: + return cls.WriteU20 + elif op == OpCode.WRITE_LEN_U32: + return cls.WriteU32 else: raise RuntimeError + + @classmethod + def from_bytes( + cls, b: bytes, offset: int, original_offset: int + ) -> Tuple[Self, int, int]: + opcode = OpCode.from_byte(b[offset]) + op_class = cls._from_opcode(opcode) + s = op_class.from_buffer_copy(b, offset) hdr_len = ctypes.sizeof(s) return ( @@ -386,7 +411,7 @@ def ctypes_class(self): @classmethod def from_bytes(cls, b: bytes, offset: int, original_offset: int): assert b[offset] == OpCode.PATCH - operations = [] + operations: List[Instr] = [] length = 1 while True: @@ -489,7 +514,7 @@ class ArrayValidation(ctypes.LittleEndianStructure): @classmethod def _naive_diff(cls, old: bytes, new: bytes, hash_len: int = 8): """Construct basic runs of WRITE, COPY, and SET_ADDR instructions""" - instr = [] + instr: List[Instr] = [] old_offset = 0 new_offset = 0 write_start = 0 @@ -628,7 +653,7 @@ def _cleanup_jumps(cls, old: bytes, instructions: List[Instr]) -> List[Instr]: def _merge_operations(cls, instructions: List[Instr]) -> List[Instr]: """Merge runs of COPY and WRITE into PATCH""" merged: List[Instr] = [] - to_merge = [] + to_merge: List[Instr] = [] def finalise(): nonlocal merged @@ -664,7 +689,7 @@ def finalise(): def _write_crack(cls, old: bytes, instructions: List[Instr]) -> List[Instr]: """Crack a WRITE operation into a [WRITE,COPY,WRITE] if COPY is at least 2 bytes""" - cracked = [] + cracked: List[Instr] = [] old_offset = 0 while len(instructions): @@ -741,8 +766,10 @@ def _write_crack(cls, old: bytes, instructions: List[Instr]) -> List[Instr]: return cracked @classmethod - def _gen_patch_instr(cls, bin_orig: bytes, bin_new: bytes) -> List[Instr]: - best_patch = None + def _gen_patch_instr( + cls, bin_orig: bytes, bin_new: bytes + ) -> Tuple[Dict, List[Instr]]: + best_patch = [] best_patch_len = 2**32 # Find best diff across range @@ -867,7 +894,7 @@ def generate( ) if verbose: - class_count = defaultdict(int) + class_count: Dict[OpCode, int] = defaultdict(int) for instr in instructions: class_count[instr.ctypes_class().op] += 1 @@ -890,7 +917,7 @@ def validation( assert len(bin_original) > 1024 # Manually construct an instruction set that runs all instructions - instructions = [] + instructions: List[Instr] = [] instructions.append( WriteInstr(bin_original[:8], cls_override=WriteInstr.WriteU4) ) @@ -1010,7 +1037,7 @@ def dump( f" Patch File: {len(bin_patch)} bytes ({len(instructions):5d} instructions)" ) - class_count = defaultdict(int) + class_count: Dict[OpCode, int] = defaultdict(int) for instr in instructions: class_count[instr.ctypes_class().op] += 1 if isinstance(instr, WriteInstr): @@ -1028,8 +1055,8 @@ def dump( print("") print("Instruction Count:") - for cls, count in sorted(class_count.items()): - print(f"{cls.name:>16s}: {count}") + for op_cls, count in sorted(class_count.items()): + print(f"{op_cls.name:>16s}: {count}") print("") print("Instruction List:") diff --git a/src/infuse_iot/epacket/interface.py b/src/infuse_iot/epacket/interface.py index b76510f..8f4f396 100644 --- a/src/infuse_iot/epacket/interface.py +++ b/src/infuse_iot/epacket/interface.py @@ -120,4 +120,4 @@ def from_bytes(cls, interface: ID, stream: bytes) -> Self: ] c = cls.BluetoothLeAddr.CtypesFormat.from_buffer_copy(stream) - return cls.BluetoothLeAddr(c.type, int.from_bytes(bytes(c.addr), "little")) + return cls(cls.BluetoothLeAddr(c.type, int.from_bytes(bytes(c.addr), "little"))) diff --git a/src/infuse_iot/epacket/packet.py b/src/infuse_iot/epacket/packet.py index fc2932c..8be1a6f 100644 --- a/src/infuse_iot/epacket/packet.py +++ b/src/infuse_iot/epacket/packet.py @@ -6,7 +6,7 @@ import time import random -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Any from typing_extensions import Self from infuse_iot.common import InfuseType @@ -149,7 +149,7 @@ def from_serial(cls, database: DeviceDatabase, serial_frame: bytes) -> List[Self del packet_bytes[: ctypes.sizeof(common_header)] # Only Bluetooth advertising supported for now - decode_mapping = { + decode_mapping: Dict[Interface, Any] = { Interface.BT_ADV: CtypeBtAdvFrame, Interface.BT_PERIPHERAL: CtypeBtGattFrame, Interface.BT_CENTRAL: CtypeBtGattFrame, @@ -166,7 +166,7 @@ def from_serial(cls, database: DeviceDatabase, serial_frame: bytes) -> List[Self if common_header.encrypted: try: f_header, f_decrypted = frame_type.decrypt( - database, addr, packet_bytes + database, addr.val, packet_bytes ) except NoKeyError: continue @@ -198,7 +198,7 @@ def from_serial(cls, database: DeviceDatabase, serial_frame: bytes) -> List[Self del packet_bytes[: ctypes.sizeof(decr_header)] # Notify database of BT Addr -> Infuse ID mapping - database.observe_device(decr_header.device_id, bt_addr=addr) + database.observe_device(decr_header.device_id, bt_addr=addr.val) bt_hop = HopReceived( decr_header.device_id, @@ -240,7 +240,9 @@ def to_serial(self, database: DeviceDatabase) -> bytes: if len(self.route) == 2: # Two hops only supports Bluetooth central for now final = self.route[1] + bt_addr = database.devices[final.infuse_id].bt_addr assert final.interface == Interface.BT_CENTRAL + assert bt_addr is not None # Forwarded payload forward_payload = CtypeBtGattFrame.encrypt( @@ -251,7 +253,7 @@ def to_serial(self, database: DeviceDatabase) -> bytes: forward_hdr = CtypeForwardHeaderBtGatt( ctypes.sizeof(CtypeForwardHeaderBtGatt) + len(forward_payload), Interface.BT_CENTRAL.value, - database.devices[final.infuse_id].bt_addr.to_ctype(), + bt_addr.to_ctype(), ) ptype = InfuseType.EPACKET_FORWARD @@ -273,6 +275,10 @@ def to_serial(self, database: DeviceDatabase) -> bytes: key_metadata = database.devices[serial.infuse_id].device_id key = database.serial_device_key(serial.infuse_id, gps_time) + # Validation + assert key_metadata is not None + assert database.gateway is not None + # Create header header = CtypeSerialFrame( version=0, @@ -381,9 +387,9 @@ def device_id(self, value): @classmethod def parse(cls, frame: bytes) -> Tuple[Self, int]: - """Parse serial frame into header and payload length""" + """Parse frame into header and payload length""" return ( - CtypeV0VersionedFrame.from_buffer_copy(frame), + cls.from_buffer_copy(frame), len(frame) - ctypes.sizeof(CtypeV0VersionedFrame) - 16, ) @@ -462,15 +468,19 @@ def encrypt( key_meta = dev_state.network_id key = database.bt_gatt_network_key(infuse_id, gps_time) + # Validate + assert key_meta is not None + # Construct GATT header - header = cls() - header._type = ptype - header.flags = flags + header = cls( + _type=ptype, + flags=flags, + gps_time=gps_time, + sequence=dev_state.gatt_sequence_num(), + entropy=random.randint(0, 65535), + ) header.device_id = infuse_id header.key_metadata = key_meta - header.gps_time = gps_time - header.sequence = dev_state.gatt_sequence_num() - header.entropy = random.randint(0, 65535) # Encrypt and return payload header_bytes = bytes(header) diff --git a/src/infuse_iot/generated/rpc_definitions.py b/src/infuse_iot/generated/rpc_definitions.py index a7e6215..b7a433b 100644 --- a/src/infuse_iot/generated/rpc_definitions.py +++ b/src/infuse_iot/generated/rpc_definitions.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: ignore-errors """Autogenerated RPC definitions""" diff --git a/src/infuse_iot/generated/tdf_base.py b/src/infuse_iot/generated/tdf_base.py index de0fca6..6c65d14 100644 --- a/src/infuse_iot/generated/tdf_base.py +++ b/src/infuse_iot/generated/tdf_base.py @@ -3,7 +3,7 @@ import ctypes from typing import Generator, Any -from typing_extensions import Self +from typing_extensions import Self, cast def _public_name(internal_field): @@ -91,12 +91,15 @@ def from_buffer_consume(cls, source: bytes, offset: int = 0) -> Self: last_field = cls._fields_[-1] # Last value not a VLA - if getattr(last_field[1], "_length_", 1) != 0: + if not issubclass(last_field[1], ctypes.Array): + return cls.from_buffer_copy(source, offset) + last_field_type: ctypes.Array = last_field[1] # type: ignore + if last_field_type._length_ != 0: return cls.from_buffer_copy(source, offset) base_size = ctypes.sizeof(cls) var_name = last_field[0] - var_type = last_field[1]._type_ + var_type = last_field_type._type_ var_type_size = ctypes.sizeof(var_type) source_var_len = len(source) - base_size @@ -107,11 +110,11 @@ def from_buffer_consume(cls, source: bytes, offset: int = 0) -> Self: # Dynamically create subclass with correct length class TdfVLA(ctypes.LittleEndianStructure): name = cls.name - _fields_ = cls._fields_[:-1] + [(var_name, source_var_num * var_type)] + _fields_ = cls._fields_[:-1] + [(var_name, source_var_num * var_type)] # type: ignore _pack_ = 1 _postfix_ = cls._postfix_ _display_fmt_ = cls._display_fmt_ iter_fields = cls.iter_fields field_information = cls.field_information - return TdfVLA.from_buffer_copy(source, offset) + return cast(Self, TdfVLA.from_buffer_copy(source, offset)) diff --git a/src/infuse_iot/generated/tdf_definitions.py b/src/infuse_iot/generated/tdf_definitions.py index 1e051f7..2241e38 100644 --- a/src/infuse_iot/generated/tdf_definitions.py +++ b/src/infuse_iot/generated/tdf_definitions.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 - +# mypy: ignore-errors """Autogenerated TDF decoding logic""" import ctypes @@ -125,6 +125,7 @@ class tdf_struct_lte_cell_id_global(TdfStructBase): "tac": "{}", } + class readings: class announce(TdfReadingBase): """Common announcement packet""" @@ -802,6 +803,7 @@ class array_type(TdfReadingBase): "array": "{}", } + id_type_mapping: Dict[int, TdfReadingBase] = { 1: readings.announce, 2: readings.battery_state, diff --git a/src/infuse_iot/rpc_wrappers/lte_at_cmd.py b/src/infuse_iot/rpc_wrappers/lte_at_cmd.py index 476dd58..7d795ea 100644 --- a/src/infuse_iot/rpc_wrappers/lte_at_cmd.py +++ b/src/infuse_iot/rpc_wrappers/lte_at_cmd.py @@ -9,11 +9,9 @@ class lte_at_cmd(InfuseRpcCommand, defs.lte_at_cmd): class request(ctypes.LittleEndianStructure): - _fields_ = [] _pack_ = 1 class response(InfuseRpcCommand.VariableSizeResponse): - base_fields = [] var_name = "rsp" var_type = ctypes.c_char diff --git a/src/infuse_iot/socket_comms.py b/src/infuse_iot/socket_comms.py index d1ba448..1ee01ab 100644 --- a/src/infuse_iot/socket_comms.py +++ b/src/infuse_iot/socket_comms.py @@ -5,7 +5,7 @@ import json import enum -from typing import Dict +from typing import Dict, Any from typing_extensions import Self from infuse_iot.epacket.packet import PacketReceived, PacketOutput @@ -34,7 +34,7 @@ def __init__( def to_json(self) -> Dict: """Convert class to json dictionary""" - out = {"type": int(self.type)} + out: Dict[str, Any] = {"type": int(self.type)} if self.epacket: out["epacket"] = self.epacket.to_json() if self.connection_id: @@ -75,7 +75,7 @@ def __init__( def to_json(self) -> Dict: """Convert class to json dictionary""" - out = {"type": int(self.type)} + out: Dict[str, Any] = {"type": int(self.type)} if self.epacket: out["epacket"] = self.epacket.to_json() if self.connection_id: diff --git a/src/infuse_iot/tdf.py b/src/infuse_iot/tdf.py index d59ecb7..de275d2 100644 --- a/src/infuse_iot/tdf.py +++ b/src/infuse_iot/tdf.py @@ -3,7 +3,7 @@ import ctypes import enum -from typing import List, Generator +from typing import List, Generator, Type from infuse_iot.time import InfuseTime from infuse_iot.generated import tdf_definitions, tdf_base @@ -59,7 +59,7 @@ def offset(self): class Reading: def __init__( self, - time: float, + time: None | float, period: None | float, data: List[tdf_base.TdfReadingBase], ): @@ -71,7 +71,7 @@ def __init__(self): pass @staticmethod - def _buffer_pull(buffer: bytes, ctype: ctypes.LittleEndianStructure): + def _buffer_pull(buffer: bytes, ctype: Type[ctypes.LittleEndianStructure]): v = ctype.from_buffer_copy(buffer) b = buffer[ctypes.sizeof(ctype) :] return v, b @@ -110,6 +110,7 @@ def decode(self, buffer: bytes) -> Generator[Reading, None, None]: total_data = buffer[:total_len] buffer = buffer[total_len:] + assert buffer_time is not None time = InfuseTime.unix_time_from_epoch(buffer_time) data = [ id_type.from_buffer_consume(total_data[x : x + header.len]) diff --git a/src/infuse_iot/time.py b/src/infuse_iot/time.py index 1813fe4..0a6bc4a 100644 --- a/src/infuse_iot/time.py +++ b/src/infuse_iot/time.py @@ -6,9 +6,9 @@ class InfuseTimeSource(enum.IntFlag): NONE = 0 - GNSS = 1 - NTP = 2 - RPC = 3 + GNSS = 0x01 + NTP = 0x02 + RPC = 0x04 RECOVERED = 0x80 def __str__(self) -> str: @@ -17,7 +17,11 @@ def __str__(self) -> str: if v & self.RECOVERED: postfix = " (recovered after reboot)" v ^= self.RECOVERED - return InfuseTimeSource(v).name + postfix + flags = InfuseTimeSource(v) + if flags.name: + return flags.name + postfix + else: + return "Unknown" + postfix class InfuseTime: diff --git a/src/infuse_iot/tools/gateway.py b/src/infuse_iot/tools/gateway.py index 6001a3f..3310338 100644 --- a/src/infuse_iot/tools/gateway.py +++ b/src/infuse_iot/tools/gateway.py @@ -16,6 +16,8 @@ import io import base64 +from typing import Dict, Callable + from infuse_iot.util.argparse import ValidFile from infuse_iot.util.console import Console from infuse_iot.common import InfuseType, InfuseID @@ -50,9 +52,9 @@ class LocalRpcServer: def __init__(self, database: DeviceDatabase): self._cnt = random.randint(0, 2**31) self._ddb = database - self._queued = {} + self._queued: Dict[int, Callable | None] = {} - def generate(self, command: int, args: bytes, auth: Auth, cb): + def generate(self, command: int, args: bytes, auth: Auth, cb: Callable | None): """Generate RPC packet from arguments""" cmd_bytes = bytes(rpc.RequestHeader(self._cnt, command)) + args cmd_pkt = PacketOutputRouted( @@ -60,6 +62,7 @@ def generate(self, command: int, args: bytes, auth: Auth, cb): InfuseType.RPC_CMD, cmd_bytes, ) + assert self._ddb.gateway is not None cmd_pkt.route[0].infuse_id = self._ddb.gateway self._queued[self._cnt] = cb self._cnt += 1 @@ -114,7 +117,7 @@ def __init__( self.ddb = ddb self.rpc = rpc_server - def query_device_key(self, cb_event: threading.Event = None): + def query_device_key(self, cb_event: threading.Event | None = None): def security_state_done(pkt: PacketReceived, _: int, response: bytes): cloud_key = response[:32] device_key = response[32:64] @@ -211,6 +214,7 @@ def _handle_serial_frame(self, frame: bytearray): try: decoded = PacketReceived.from_serial(self._common.ddb, frame) except NoKeyError: + assert self._common.ddb.gateway is not None if not self._common.ddb.has_network_id(self._common.ddb.gateway): # Need to know network ID before we can query the device key self._common.port.ping() @@ -256,7 +260,7 @@ def __init__( common: CommonThreadState, ): self._common = common - self._queue = queue.Queue() + self._queue: queue.Queue = queue.Queue() super().__init__(self._iter) def send(self, pkt): @@ -269,6 +273,7 @@ def _handle_epacket_send(self, req: GatewayRequest): return pkt = req.epacket + assert pkt is not None # Construct routed output if pkt.infuse_id == InfuseID.GATEWAY: @@ -321,6 +326,8 @@ def _bt_connect_cb(self, pkt: PacketReceived, rc: int, response: bytes): self._common.server.broadcast(rsp) def _handle_conn_request(self, req: GatewayRequest): + assert req.connection_id is not None + if req.connection_id == InfuseID.GATEWAY: # Local gateway always connected rsp = ClientNotification( @@ -339,10 +346,8 @@ def _handle_conn_request(self, req: GatewayRequest): self._common.server.broadcast(rsp) return - device_info = self._common.ddb.devices[req.connection_id] - connect_args = defs.bt_connect_infuse.request( - device_info.bt_addr.to_rpc_struct(), + state.bt_addr.to_rpc_struct(), 10000, defs.rpc_enum_infuse_bt_characteristic.COMMAND, 0, @@ -358,6 +363,8 @@ def _handle_conn_request(self, req: GatewayRequest): self._common.port.write(encrypted) def _handle_conn_release(self, req: GatewayRequest): + assert req.connection_id is not None + if req.connection_id == InfuseID.GATEWAY: # Local gateway always connected return @@ -369,7 +376,7 @@ def _handle_conn_release(self, req: GatewayRequest): disconnect_args = defs.bt_disconnect.request(state.bt_addr.to_rpc_struct()) cmd = self._common.rpc.generate( - defs.bt_disconnect.COMMAND_ID, bytes(disconnect_args) + defs.bt_disconnect.COMMAND_ID, bytes(disconnect_args), Auth.DEVICE, None ) encrypted = cmd.to_serial(self._common.ddb) Console.log_tx(cmd.ptype, len(encrypted)) diff --git a/src/infuse_iot/tools/native_bt.py b/src/infuse_iot/tools/native_bt.py index 2af0a4a..9be04ef 100644 --- a/src/infuse_iot/tools/native_bt.py +++ b/src/infuse_iot/tools/native_bt.py @@ -47,11 +47,15 @@ def __init__(self, args): Console.init() def simple_callback(self, device: BLEDevice, data: AdvertisementData): - addr = interface.Address.BluetoothLeAddr(0, BtLeAddress(device.address)) + addr = interface.Address( + interface.Address.BluetoothLeAddr( + 0, BtLeAddress.integer_value(device.address) + ) + ) rssi = data.rssi payload = data.manufacturer_data[self.infuse_manu] - hdr, decr = CtypeBtAdvFrame.decrypt(self.database, payload) + hdr, decr = CtypeBtAdvFrame.decrypt(self.database, addr.val, payload) hop = HopReceived( hdr.device_id, diff --git a/src/infuse_iot/tools/provision.py b/src/infuse_iot/tools/provision.py index e17b691..6af5cb2 100644 --- a/src/infuse_iot/tools/provision.py +++ b/src/infuse_iot/tools/provision.py @@ -77,7 +77,7 @@ def __init__(self, args): key, val = meta.strip().split("=", 1) self._metadata[key.strip()] = val - def nrf_device_info(self, api: LowLevel.API) -> tuple[int, int]: + def nrf_device_info(self, api: LowLevel.API) -> tuple[str, int, int]: """Retrive device ID and customer UICR address""" device_id_offsets = { # nRF52840 only @@ -120,6 +120,9 @@ def nrf_device_info(self, api: LowLevel.API) -> tuple[int, int]: dev_id_bytes = bytes(api.read(dev_id_addr, 8)) dev_id = int.from_bytes(dev_id_bytes, "big") + assert uicr_addr is not None + assert dev_id is not None + return soc, uicr_addr, dev_id def create_device(self, client, soc, hardware_id_str): @@ -223,7 +226,7 @@ def run(self): f"HW ID 0x{hardware_id:016x} already provisioned as 0x{desired.device_id:016x}" ) else: - if current_bytes != len(current_bytes) * b"\xFF": + if current_bytes != len(current_bytes) * b"\xff": print( f"HW ID 0x{hardware_id:016x} already has incorrect provisioning info, recover device" ) diff --git a/src/infuse_iot/tools/rpc.py b/src/infuse_iot/tools/rpc.py index c5cf386..0163a3c 100644 --- a/src/infuse_iot/tools/rpc.py +++ b/src/infuse_iot/tools/rpc.py @@ -58,7 +58,7 @@ def add_parser(cls, parser): cmd_parser.set_defaults(rpc_class=cmd_cls) cmd_cls.add_parser(cmd_parser) - def __init__(self, args): + def __init__(self, args: argparse.Namespace): self._args = args self._client = LocalClient(default_multicast_address(), 10.0) self._command: InfuseRpcCommand = args.rpc_class(args) diff --git a/src/infuse_iot/util/argparse.py b/src/infuse_iot/util/argparse.py index d6d874a..b4a82dc 100644 --- a/src/infuse_iot/util/argparse.py +++ b/src/infuse_iot/util/argparse.py @@ -4,11 +4,13 @@ import pathlib import re +from typing import cast + class ValidFile: """Filesystem path that exists""" - def __new__(cls, string) -> pathlib.Path: + def __new__(cls, string) -> pathlib.Path: # type: ignore p = pathlib.Path(string) if p.exists(): return p @@ -19,8 +21,7 @@ def __new__(cls, string) -> pathlib.Path: class BtLeAddress: """Bluetooth Low-Energy address""" - def __new__(cls, string) -> int: - + def __new__(cls, string) -> int: # type: ignore pattern = r"((([0-9a-fA-F]{2}):){5})([0-9a-fA-F]{2})" if re.match(pattern, string): @@ -32,3 +33,7 @@ def __new__(cls, string) -> int: except ValueError: raise argparse.ArgumentTypeError(f"{string} is not a Bluetooth address") return addr + + @classmethod + def integer_value(cls, string) -> int: + return cast(int, cls(string)) diff --git a/tox.ini b/tox.ini index 69fb1e2..f5d4d32 100644 --- a/tox.ini +++ b/tox.ini @@ -36,6 +36,9 @@ deps = pytest pytest-cov types-PyYAML + pandas-stubs + types-tabulate + mypy ruff setenv = # For instance: ./.tox/py3/tmp/ @@ -43,3 +46,4 @@ setenv = commands = python -m pytest --basetemp='{envtmpdir}/pytest with space/' python -m ruff check '{toxinidir}' --extend-exclude 'api_client|tdf_definitions.py|rpc_definitions.py' + python -m mypy '{toxinidir}'