Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
line-length = 120
indent-width = 4

exclude = [
"src/infuse_iot/api_client/"
]

[lint]
select = [
"B", # flake8-bugbear
"E", # pycodestyle
"F", # pyflakes
"I", # isort
"SIM", # flake8-simplify
"UP", # pyupgrade
"W", # pycodestyle warnings
]
ignore = [
"SIM105", # Allow try-except-pass
"SIM108", # Allow if-else blocks instead of forcing ternary operator
]

[format]
line-ending = "lf"
15 changes: 6 additions & 9 deletions src/infuse_iot/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
__copyright__ = "Copyright 2024, Embeint Inc"

import argparse
import sys
import pkgutil
import importlib.util
import pkgutil
import sys

import argcomplete

import infuse_iot.tools
Expand All @@ -23,9 +24,7 @@ class InfuseApp:
def __init__(self):
self.args = None
self.parser = argparse.ArgumentParser("infuse")
self.parser.add_argument(
"--version", action="version", version=f"{__version__}"
)
self.parser.add_argument("--version", action="version", version=f"{__version__}")
self._tools = {}
# Load tools
self._load_tools(self.parser)
Expand All @@ -40,17 +39,15 @@ def run(self, argv):
tool.run()

def _load_tools(self, parser: argparse.ArgumentParser):
tools_parser = parser.add_subparsers(
title="commands", metavar="<command>", required=True
)
tools_parser = parser.add_subparsers(title="commands", metavar="<command>", required=True)

# Iterate over tools
for _, name, _ in pkgutil.walk_packages(infuse_iot.tools.__path__):
full_name = f"{infuse_iot.tools.__name__}.{name}"
module = importlib.import_module(full_name)

# Add tool to parser
tool_cls: InfuseCommand = getattr(module, "SubCommand")
tool_cls: InfuseCommand = module.SubCommand
parser = tools_parser.add_parser(
tool_cls.NAME,
help=tool_cls.HELP,
Expand Down
7 changes: 2 additions & 5 deletions src/infuse_iot/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
import argparse
import ctypes

from typing import List, Type, Tuple


from infuse_iot.epacket.packet import Auth


Expand Down Expand Up @@ -63,9 +60,9 @@ def handle_response(self, return_code, response):
raise NotImplementedError

class VariableSizeResponse:
base_fields: List[Tuple[str, Type[ctypes._SimpleCData]]] = []
base_fields: list[tuple[str, type[ctypes._SimpleCData]]] = []
var_name = "x"
var_type: Type[ctypes._SimpleCData] = ctypes.c_ubyte
var_type: type[ctypes._SimpleCData] = ctypes.c_ubyte

@classmethod
def from_buffer_copy(cls, source, offset=0):
Expand Down
4 changes: 1 addition & 3 deletions src/infuse_iot/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,5 @@ def load_network(network_id: int):
username = f"network-{network_id:06x}"
key = keyring.get_password("infuse-iot", username)
if key is None:
raise FileNotFoundError(
f"Network key {network_id:06x} does not exist in keyring"
)
raise FileNotFoundError(f"Network key {network_id:06x} does not exist in keyring")
return yaml.safe_load(key)
39 changes: 14 additions & 25 deletions src/infuse_iot/database.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
#!/usr/bin/env python3

import binascii
import base64
from typing import Dict, Tuple
import binascii

from infuse_iot.api_client import Client
from infuse_iot.api_client.api.default import get_shared_secret
from infuse_iot.api_client.models import Key
from infuse_iot.util.crypto import hkdf_derive
from infuse_iot.epacket.interface import Address as InterfaceAddress
from infuse_iot.credentials import get_api_key, load_network
from infuse_iot.epacket.interface import Address as InterfaceAddress
from infuse_iot.util.crypto import hkdf_derive


class NoKeyError(KeyError):
Expand All @@ -36,9 +35,10 @@ class DeviceDatabase:
"""Database of current device state"""

_network_keys = {
0x000000: b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f",
0x000000: b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"
b"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f",
}
_derived_keys: Dict[Tuple[int, bytes, int], bytes] = {}
_derived_keys: dict[tuple[int, bytes, int], bytes] = {}

class DeviceState:
"""Device State"""
Expand All @@ -64,8 +64,8 @@ def gatt_sequence_num(self):

def __init__(self) -> None:
self.gateway: int | None = None
self.devices: Dict[int, DeviceDatabase.DeviceState] = {}
self.bt_addr: Dict[InterfaceAddress.BluetoothLeAddr, int] = {}
self.devices: dict[int, DeviceDatabase.DeviceState] = {}
self.bt_addr: dict[InterfaceAddress.BluetoothLeAddr, int] = {}

def observe_device(
self,
Expand All @@ -82,21 +82,14 @@ def observe_device(
if network_id is not None:
self.devices[address].network_id = network_id
if device_id is not None:
if (
self.devices[address].device_id is not None
and self.devices[address].device_id != device_id
):
raise DeviceKeyChangedError(
f"Device key for {address:016x} has changed"
)
if self.devices[address].device_id is not None and self.devices[address].device_id != device_id:
raise DeviceKeyChangedError(f"Device key for {address:016x} has changed")
self.devices[address].device_id = device_id
if bt_addr is not None:
self.bt_addr[bt_addr] = address
self.devices[address].bt_addr = bt_addr

def observe_security_state(
self, address: int, cloud_key: bytes, device_key: bytes, network_id: int
) -> None:
def observe_security_state(self, address: int, cloud_key: bytes, device_key: bytes, network_id: int) -> None:
"""Update device state based on security_state response"""
if address not in self.devices:
self.devices[address] = self.DeviceState(address)
Expand All @@ -121,16 +114,14 @@ def _network_key(self, network_id: int, interface: bytes, gps_time: int) -> byte
try:
info = load_network(network_id)
except FileNotFoundError:
raise UnknownNetworkError
raise UnknownNetworkError from None
self._network_keys[network_id] = info["key"]
base = self._network_keys[network_id]
time_idx = gps_time // (60 * 60 * 24)

key_id = (network_id, interface, time_idx)
if key_id not in self._derived_keys:
self._derived_keys[key_id] = hkdf_derive(
base, time_idx.to_bytes(4, "little"), interface
)
self._derived_keys[key_id] = hkdf_derive(base, time_idx.to_bytes(4, "little"), interface)

return self._derived_keys[key_id]

Expand All @@ -155,9 +146,7 @@ def has_network_id(self, address: int) -> bool:
return False
return self.devices[address].network_id is not None

def infuse_id_from_bluetooth(
self, bt_addr: InterfaceAddress.BluetoothLeAddr
) -> int | None:
def infuse_id_from_bluetooth(self, bt_addr: InterfaceAddress.BluetoothLeAddr) -> int | None:
"""Get Bluetooth address associated with device"""
return self.bt_addr.get(bt_addr, None)

Expand Down
Loading