Skip to content

Commit

Permalink
Reuse ZDO Initializers to create Endpoint objects on EZSP device (#599)
Browse files Browse the repository at this point in the history
* Ensure device endpoints sync with register_endpoints

This way if we need a reference to one we can grab it in a way we expect

* Make the linter and tests happy

* Better names, better comments.

* Merge complex endpoint functionality into simple endpoint

* Refactor unit tests to allow slightly better mocking of startup

* Simplify endpoint creation

---------

Co-authored-by: puddly <32534428+puddly@users.noreply.github.com>
  • Loading branch information
konistehrad and puddly committed Dec 30, 2023
1 parent f9f4c3f commit 758803d
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 51 deletions.
28 changes: 21 additions & 7 deletions bellows/zigbee/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class ControllerApplication(zigpy.application.ControllerApplication):
def __init__(self, config: dict):
super().__init__(config)
self._ctrl_event = asyncio.Event()
self._created_device_endpoints: list[zdo_t.SimpleDescriptor] = []
self._ezsp = None
self._multicast = None
self._mfg_id_task: asyncio.Task | None = None
Expand Down Expand Up @@ -116,9 +117,12 @@ async def add_endpoint(self, descriptor: zdo_t.SimpleDescriptor) -> None:
descriptor.input_clusters,
descriptor.output_clusters,
)

if status != t.EmberStatus.SUCCESS:
raise StackAlreadyRunning()

self._created_device_endpoints.append(descriptor)

async def cleanup_tc_link_key(self, ieee: t.EUI64) -> None:
"""Remove tc link_key for the given device."""
(index,) = await self._ezsp.findKeyTableEntry(ieee, True)
Expand Down Expand Up @@ -150,6 +154,8 @@ async def connect(self) -> None:
raise

self._ezsp = ezsp

self._created_device_endpoints.clear()
await self.register_endpoints()

async def _ensure_network_running(self) -> bool:
Expand Down Expand Up @@ -198,10 +204,15 @@ async def start_network(self):
ezsp.add_callback(self.ezsp_callback_handler)
self.controller_event.set()

group_membership = {}

try:
db_device = self.get_device(ieee=self.state.node_info.ieee)
except KeyError:
db_device = None
pass
else:
if 1 in db_device.endpoints:
group_membership = db_device.endpoints[1].member_of

ezsp_device = zigpy.device.Device(
application=self,
Expand All @@ -210,15 +221,18 @@ async def start_network(self):
)
self.devices[self.state.node_info.ieee] = ezsp_device

# The coordinator device does not respond to attribute reads
ezsp_device.endpoints[1] = EZSPEndpoint(ezsp_device, 1)
ezsp_device.model = ezsp_device.endpoints[1].model
ezsp_device.manufacturer = ezsp_device.endpoints[1].manufacturer
# The coordinator device does not respond to attribute reads so we have to
# divine the internal NCP state.
for zdo_desc in self._created_device_endpoints:
ep = EZSPEndpoint(ezsp_device, zdo_desc.endpoint, zdo_desc)
ezsp_device.endpoints[zdo_desc.endpoint] = ep
ezsp_device.model = ep.model
ezsp_device.manufacturer = ep.manufacturer

await ezsp_device.schedule_initialize()

# Group membership is stored in the database for EZSP coordinators
if db_device is not None and 1 in db_device.endpoints:
ezsp_device.endpoints[1].member_of.update(db_device.endpoints[1].member_of)
ezsp_device.endpoints[1].member_of.update(group_membership)

self._multicast = bellows.multicast.Multicast(ezsp)
await self._multicast.startup(ezsp_device)
Expand Down
41 changes: 35 additions & 6 deletions bellows/zigbee/device.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,51 @@
from __future__ import annotations

import logging
import typing

import zigpy.device
import zigpy.endpoint
import zigpy.util
import zigpy.zdo
import zigpy.profiles.zgp
import zigpy.profiles.zha
import zigpy.profiles.zll
import zigpy.zdo.types as zdo_t

import bellows.types as t

if typing.TYPE_CHECKING:
import zigpy.application # pragma: no cover

LOGGER = logging.getLogger(__name__)

PROFILE_TO_DEVICE_TYPE = {
zigpy.profiles.zha.PROFILE_ID: zigpy.profiles.zha.DeviceType,
zigpy.profiles.zll.PROFILE_ID: zigpy.profiles.zll.DeviceType,
zigpy.profiles.zgp.PROFILE_ID: zigpy.profiles.zgp.DeviceType,
}


class EZSPEndpoint(zigpy.endpoint.Endpoint):
def __init__(
self,
device: zigpy.device.Device,
endpoint_id: int,
descriptor: zdo_t.SimpleDescriptor,
) -> None:
super().__init__(device, endpoint_id)

self.profile_id = descriptor.profile

if self.profile_id in PROFILE_TO_DEVICE_TYPE:
self.device_type = PROFILE_TO_DEVICE_TYPE[self.profile_id](
descriptor.device_type
)
else:
self.device_type = descriptor.device_type

for cluster in descriptor.input_clusters:
self.add_input_cluster(cluster)

for cluster in descriptor.output_clusters:
self.add_output_cluster(cluster)

self.status = zigpy.endpoint.Status.ZDO_INIT

@property
def manufacturer(self) -> str:
"""Manufacturer."""
Expand Down
90 changes: 52 additions & 38 deletions tests/test_application.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextlib
import logging
from unittest.mock import AsyncMock, MagicMock, PropertyMock, call, patch, sentinel

Expand Down Expand Up @@ -114,7 +115,6 @@ def aps():
return f


@patch("zigpy.device.Device._initialize", new=AsyncMock())
def _create_app_for_startup(
app,
nwk_type,
Expand Down Expand Up @@ -206,6 +206,14 @@ async def mock_leave(*args, **kwargs):
),
]
)
ezsp_mock.getMulticastTableEntry = AsyncMock(
return_value=[
t.EmberStatus.SUCCESS,
t.EmberMulticastTableEntry(multicastId=0x0000, endpoint=0, networkIndex=0),
]
)
ezsp_mock.setMulticastTableEntry = AsyncMock(return_value=[t.EmberStatus.SUCCESS])

app.permit = AsyncMock()

def form_network():
Expand All @@ -220,10 +228,11 @@ def form_network():
return ezsp_mock


async def _test_startup(
@contextlib.contextmanager
def mock_for_startup(
app,
nwk_type,
ieee,
nwk_type=t.EmberNodeType.COORDINATOR,
auto_form=False,
init=0,
ezsp_version=4,
Expand All @@ -234,10 +243,25 @@ async def _test_startup(
app, nwk_type, ieee, auto_form, init, ezsp_version, board_info, network_state
)

p1 = patch("bellows.ezsp.EZSP", return_value=ezsp_mock)
p2 = patch.object(bellows.multicast.Multicast, "startup")
with patch("bellows.ezsp.EZSP", return_value=ezsp_mock), patch(
"zigpy.device.Device._initialize", new=AsyncMock()
):
yield ezsp_mock


with p1, p2 as multicast_mock:
async def _test_startup(
app,
nwk_type,
ieee,
auto_form=False,
init=0,
ezsp_version=4,
board_info=True,
network_state=t.EmberNetworkStatus.JOINED_NETWORK,
):
with mock_for_startup(
app, ieee, nwk_type, auto_form, init, ezsp_version, board_info, network_state
) as ezsp_mock:
await app.startup(auto_form=auto_form)

if ezsp_version > 6:
Expand All @@ -247,7 +271,6 @@ async def _test_startup(

assert ezsp_mock.write_config.call_count == 1
assert ezsp_mock.addEndpoint.call_count >= 2
assert multicast_mock.await_count == 1


async def test_startup(app, ieee):
Expand Down Expand Up @@ -1166,7 +1189,7 @@ async def test_shutdown(app):
@pytest.fixture
def coordinator(app, ieee):
dev = zigpy.device.Device(app, ieee, 0x0000)
dev.endpoints[1] = bellows.zigbee.device.EZSPEndpoint(dev, 1)
dev.endpoints[1] = bellows.zigbee.device.EZSPEndpoint(dev, 1, MagicMock())
dev.model = dev.endpoints[1].model
dev.manufacturer = dev.endpoints[1].manufacturer

Expand Down Expand Up @@ -1505,42 +1528,32 @@ async def test_ensure_network_running_not_joined_success(app):

async def test_startup_coordinator_existing_groups_joined(app, ieee):
"""Coordinator joins groups loaded from the database."""
with mock_for_startup(app, ieee) as ezsp_mock:
await app.connect()

app._ensure_network_running = AsyncMock()
app._ezsp.update_policies = AsyncMock()
app.load_network_info = AsyncMock()
app.state.node_info.ieee = ieee

db_device = app.add_device(ieee, 0x0000)
db_ep = db_device.add_endpoint(1)

app.groups.add_group(0x1234, "Group Name", suppress_event=True)
app.groups[0x1234].add_member(db_ep, suppress_event=True)
db_device = app.add_device(ieee, 0x0000)
db_ep = db_device.add_endpoint(1)

p1 = patch.object(bellows.multicast.Multicast, "_initialize")
p2 = patch.object(bellows.multicast.Multicast, "subscribe")
app.groups.add_group(0x1234, "Group Name", suppress_event=True)
app.groups[0x1234].add_member(db_ep, suppress_event=True)

with p1 as p1, p2 as p2:
await app.start_network()

p2.assert_called_once_with(0x1234)
assert ezsp_mock.setMulticastTableEntry.mock_calls == [
call(
0,
t.EmberMulticastTableEntry(multicastId=0x1234, endpoint=1, networkIndex=0),
)
]


async def test_startup_new_coordinator_no_groups_joined(app, ieee):
"""Coordinator freshy added to the database has no groups to join."""

app._ensure_network_running = AsyncMock()
app._ezsp.update_policies = AsyncMock()
app.load_network_info = AsyncMock()
app.state.node_info.ieee = ieee

p1 = patch.object(bellows.multicast.Multicast, "_initialize")
p2 = patch.object(bellows.multicast.Multicast, "subscribe")

with p1 as p1, p2 as p2:
with mock_for_startup(app, ieee) as ezsp_mock:
await app.connect()
await app.start_network()

p2.assert_not_called()
assert ezsp_mock.setMulticastTableEntry.mock_calls == []


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1628,22 +1641,23 @@ async def test_connect_failure(
assert len(ezsp_mock.close.mock_calls) == 1


async def test_repair_tclk_partner_ieee(app: ControllerApplication) -> None:
async def test_repair_tclk_partner_ieee(
app: ControllerApplication, ieee: zigpy_t.EUI64
) -> None:
"""Test that EZSP is reset after repairing TCLK."""
app._ensure_network_running = AsyncMock()
app._reset = AsyncMock()
app.load_network_info = AsyncMock()

with patch(
with mock_for_startup(app, ieee), patch(
"bellows.zigbee.repairs.fix_invalid_tclk_partner_ieee",
AsyncMock(return_value=False),
):
await app.connect()
await app.start_network()

assert len(app._reset.mock_calls) == 0
app._reset.reset_mock()

with patch(
with mock_for_startup(app, ieee), patch(
"bellows.zigbee.repairs.fix_invalid_tclk_partner_ieee",
AsyncMock(return_value=True),
):
Expand Down

0 comments on commit 758803d

Please sign in to comment.