diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3643c6d..b6b61cf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,7 +2,11 @@ name: Python Test # yamllint disable-line rule:truthy -on: [push, pull_request] +on: + push: + branches: + - main + pull_request: jobs: build: diff --git a/pyproject.toml b/pyproject.toml index e073f3c..4f0d8ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "tabulate", "typing-extensions", "keyring", + "pyyaml", ] [project.license] diff --git a/src/infuse_iot/credentials.py b/src/infuse_iot/credentials.py index 08d2e40..d08ec45 100644 --- a/src/infuse_iot/credentials.py +++ b/src/infuse_iot/credentials.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 import keyring +import yaml -def set_api_key(api_key): +def set_api_key(api_key: str): """ Save the Infuse-IoT API key to the keyring module """ @@ -18,3 +19,24 @@ def get_api_key(): if key is None: raise FileNotFoundError("API key does not exist in keyring") return key + + +def save_network(network_id: int, network_info: str): + """ + Save an Infuse-IoT network key to the keyring module + """ + username = f"network-{network_id:06x}" + keyring.set_password("infuse-iot", username, network_info) + + +def load_network(network_id: int): + """ + Retrieve an Infuse-IoT network key from the keyring module + """ + username = f"network-{network_id:06x}" + key = keyring.get_password("infuse-iot", username) + if key is None: + raise FileNotFoundError( + f"Network key {network_id:06x} does not exist in keyring" + ) + return yaml.safe_load(key) diff --git a/src/infuse_iot/database.py b/src/infuse_iot/database.py index 531a2ed..d0a11c2 100644 --- a/src/infuse_iot/database.py +++ b/src/infuse_iot/database.py @@ -8,23 +8,36 @@ from infuse_iot.api_client.api.default import get_shared_secret from infuse_iot.api_client.models import Key from infuse_iot.util.crypto import hkdf_derive -from infuse_iot.credentials import get_api_key +from infuse_iot.credentials import get_api_key, load_network class NoKeyError(KeyError): - pass + """Generic key not found error""" -class KeyChangedError(KeyError): - pass +class UnknownNetworkError(NoKeyError): + """Requested network is not known""" + + +class DeviceUnknownDeviceKey(NoKeyError): + """Device key is not known for requested device""" + + +class DeviceUnknownNetworkKey(NoKeyError): + """Network key is not known for requested device""" + + +class DeviceKeyChangedError(KeyError): + """Device key for the requested device has changed""" class DeviceDatabase: """Database of current device state""" _network_keys = { - 0x00: b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f", + 0x000000: b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f", } + _derived_keys = {} class DeviceState: """Device State""" @@ -55,7 +68,9 @@ def observe_serial( self.devices[address].device_id is not None and self.devices[address].device_id != device_id ): - raise KeyChangedError(f"Device key for {address:016x} has changed") + raise DeviceKeyChangedError( + f"Device key for {address:016x} has changed" + ) self.devices[address].device_id = device_id def observe_security_state( @@ -79,6 +94,24 @@ def observe_security_state( key = base64.b64decode(response.key) self.devices[address].shared_key = key + def _network_key(self, network_id: int, interface: str, gps_time: int): + if network_id not in self._network_keys: + try: + info = load_network(network_id) + except FileNotFoundError: + raise UnknownNetworkError + self._network_keys[network_id] = info["key"] + base = self._network_keys[network_id] + time_idx = gps_time // (60 * 60 * 24) + + key_id = (network_id, interface, time_idx) + if key_id not in self._derived_keys: + self._derived_keys[key_id] = hkdf_derive( + base, time_idx.to_bytes(4, "little"), interface + ) + + return self._derived_keys[key_id] + def _serial_key(self, base, time_idx): return hkdf_derive(base, time_idx.to_bytes(4, "little"), b"serial") @@ -88,29 +121,33 @@ def _bt_adv_key(self, base, time_idx): def has_public_key(self, address: int): """Does the database have the public key for this device?""" if address not in self.devices: - print(address, "not in list") return False return self.devices[address].public_key is not None + def has_network_id(self, address: int): + """Does the database know the network ID for this device?""" + if address not in self.devices: + return False + return self.devices[address].network_id is not None + def serial_network_key(self, address: int, gps_time: int): """Network key for serial interface""" if address not in self.devices: - raise NoKeyError - base = self._network_keys[self.devices[address].network_id] - time_idx = gps_time // (60 * 60 * 24) + raise DeviceUnknownNetworkKey + network_id = self.devices[address].network_id - return self._serial_key(base, time_idx) + return self._network_key(network_id, b"serial", gps_time) def serial_device_key(self, address: int, gps_time: int): """Device key for serial interface""" if address not in self.devices: - raise NoKeyError + raise DeviceUnknownDeviceKey d = self.devices[address] if d.device_id is None: - raise NoKeyError + raise DeviceUnknownDeviceKey base = self.devices[address].shared_key if base is None: - raise NoKeyError + raise DeviceUnknownDeviceKey time_idx = gps_time // (60 * 60 * 24) return self._serial_key(base, time_idx) @@ -118,22 +155,21 @@ def serial_device_key(self, address: int, gps_time: int): def bt_adv_network_key(self, address: int, gps_time: int): """Network key for Bluetooth advertising interface""" if address not in self.devices: - raise NoKeyError - base = self._network_keys[self.devices[address].network_id] - time_idx = gps_time // (60 * 60 * 24) + raise DeviceUnknownNetworkKey + network_id = self.devices[address].network_id - return self._bt_adv_key(base, time_idx) + return self._network_key(network_id, b"bt_adv", gps_time) def bt_adv_device_key(self, address: int, gps_time: int): """Device key for Bluetooth advertising interface""" if address not in self.devices: - raise NoKeyError + raise DeviceUnknownDeviceKey d = self.devices[address] if d.device_id is None: - raise NoKeyError + raise DeviceUnknownDeviceKey base = self.devices[address].shared_key if base is None: - raise NoKeyError + raise DeviceUnknownDeviceKey time_idx = gps_time // (60 * 60 * 24) return self._bt_adv_key(base, time_idx) diff --git a/src/infuse_iot/tools/credentials.py b/src/infuse_iot/tools/credentials.py new file mode 100644 index 0000000..b9071d4 --- /dev/null +++ b/src/infuse_iot/tools/credentials.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +"""Manage Infuse-IoT credentials""" + +__author__ = "Jordan Yates" +__copyright__ = "Copyright 2024, Embeint Inc" + +import yaml + +from infuse_iot.util.argparse import ValidFile +from infuse_iot.commands import InfuseCommand + +from infuse_iot import credentials + + +class SubCommand(InfuseCommand): + NAME = "credentials" + HELP = "Manage Infuse-IoT credentials" + DESCRIPTION = "Manage Infuse-IoT credentials" + + @classmethod + def add_parser(cls, parser): + parser.add_argument("--api-key", type=str, help="Set Infuse-IoT API key") + parser.add_argument( + "--network", type=ValidFile, help="Load network credentials from file" + ) + + def __init__(self, args): + self.args = args + + def run(self): + if self.args.api_key is not None: + credentials.set_api_key(self.args.api_key) + if self.args.network is not None: + # Read the file + with self.args.network.open("r") as f: + content = f.read() + # Validate it is valid yaml + network_info = yaml.safe_load(content) + credentials.save_network(network_info["id"], content) diff --git a/src/infuse_iot/tools/gateway.py b/src/infuse_iot/tools/gateway.py index 5ab8c32..6115137 100644 --- a/src/infuse_iot/tools/gateway.py +++ b/src/infuse_iot/tools/gateway.py @@ -21,8 +21,10 @@ from infuse_iot.commands import InfuseCommand from infuse_iot.serial_comms import RttPort, SerialPort, SerialFrame from infuse_iot.socket_comms import LocalServer, default_multicast_address -from infuse_iot.database import DeviceDatabase -from infuse_iot.credentials import set_api_key +from infuse_iot.database import ( + DeviceDatabase, + NoKeyError, +) from infuse_iot.epacket import ( InfuseType, @@ -177,16 +179,23 @@ class memfault_chunk_header(ctypes.LittleEndianStructure): f"Memfault Chunk {hdr.cnt:3d}: {base64.b64encode(chunk).decode('utf-8')}" ) - def _handle_serial_frame(self, frame): + def _handle_serial_frame(self, frame: bytearray): try: # Decode the serial packet try: decoded = PacketReceived.from_serial(self._common.ddb, frame) - except KeyError: - self._common.query_device_key(None) - Console.log_info( - f"Dropping {len(frame)} byte packet to query device key..." - ) + except NoKeyError: + if not self._common.ddb.has_network_id(self._common.ddb.gateway): + # Need to know network ID before we can query the device key + self._common.port.ping() + Console.log_info( + f"Dropping {len(frame)} byte packet to query network ID..." + ) + else: + self._common.query_device_key(None) + Console.log_info( + f"Dropping {len(frame)} byte packet to query device key..." + ) return except cryptography.exceptions.InvalidTag as e: Console.log_error(f"Failed to decode {len(frame)} byte packet {e}") @@ -280,11 +289,8 @@ def add_parser(cls, parser): type=argparse.FileType("w"), help="Save serial output to file", ) - parser.add_argument("--api-key", type=str, help="Update saved API key") def __init__(self, args): - if args.api_key is not None: - set_api_key(args.api_key) if args.serial is not None: self.port = SerialPort(args.serial) elif args.rtt is not None: