diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b6b104b8..26b212e3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -67,17 +67,35 @@ jobs: ${{ steps.cp39.outputs.python-path }} -m pip install .[test] echo "MANUAL_OS_SET=Windows" >> $GITHUB_ENV - - name: Perform faster tests + - name: Perform ChimeraPy utils tests run: | - ${{ steps.cp39.outputs.python-path }} -m coverage run --source=chimerapy -m pytest -v --reruns 5 --color yes --reruns-delay 5 -m "not slow" test + ${{ steps.cp39.outputs.python-path }} -m coverage run --source=chimerapy -m pytest -v --reruns 5 --color yes --reruns-delay 5 test/core ${{ steps.cp39.outputs.python-path }} -m coverage combine --append - mv chimerapy-engine-test.log chimerapy-engine-test-fast.log + mv chimerapy-engine-test.log chimerapy-engine-test-utils.log - - name: Perform slower tests + - name: Perform ChimeraPy logger tests run: | - ${{ steps.cp39.outputs.python-path }} -m coverage run --source=chimerapy -m pytest -v --reruns 5 --color yes --reruns-delay 5 -m "slow" test + ${{ steps.cp39.outputs.python-path }} -m coverage run --source=chimerapy -m pytest -v --reruns 5 --color yes --reruns-delay 5 test/logger ${{ steps.cp39.outputs.python-path }} -m coverage combine --append - mv chimerapy-engine-test.log chimerapy-engine-test-slow.log + mv chimerapy-engine-test.log chimerapy-engine-test-logger.log + + - name: Perform ChimeraPy Node tests + run: | + ${{ steps.cp39.outputs.python-path }} -m coverage run --source=chimerapy -m pytest -v --reruns 5 --color yes --reruns-delay 5 test/node + ${{ steps.cp39.outputs.python-path }} -m coverage combine --append + mv chimerapy-engine-test.log chimerapy-engine-test-node.log + + - name: Perform ChimeraPy Worker tests + run: | + ${{ steps.cp39.outputs.python-path }} -m coverage run --source=chimerapy -m pytest -v --reruns 5 --color yes --reruns-delay 5 test/worker + ${{ steps.cp39.outputs.python-path }} -m coverage combine --append 
+ mv chimerapy-engine-test.log chimerapy-engine-test-worker.log + + - name: Perform ChimeraPy Manager tests + run: | + ${{ steps.cp39.outputs.python-path }} -m coverage run --source=chimerapy -m pytest -v --reruns 5 --color yes --reruns-delay 5 test/manager + ${{ steps.cp39.outputs.python-path }} -m coverage combine --append + mv chimerapy-engine-test.log chimerapy-engine-test-manager.log - name: Combine test logs run : | diff --git a/chimerapy/engine/data_protocols.py b/chimerapy/engine/data_protocols.py index 79d7fd03..c34d073e 100644 --- a/chimerapy/engine/data_protocols.py +++ b/chimerapy/engine/data_protocols.py @@ -1,9 +1,18 @@ import datetime +import enum +import logging +import typing from dataclasses import dataclass, field -from typing import Dict +from typing import Any, Dict, List, Literal, Optional, Union + +if typing.TYPE_CHECKING: + from .graph import Graph + from .states import NodeState from dataclasses_json import DataClassJsonMixin +from .networking import DataChunk + @dataclass class NodePubEntry(DataClassJsonMixin): @@ -18,6 +27,7 @@ class NodePubTable(DataClassJsonMixin): @dataclass class NodeDiagnostics(DataClassJsonMixin): + node_id: str = "" timestamp: str = field( default_factory=lambda: str(datetime.datetime.now().isoformat()) ) # ISO str @@ -26,3 +36,85 @@ class NodeDiagnostics(DataClassJsonMixin): memory_usage: float = 0 # KB cpu_usage: float = 0 # percentage num_of_steps: int = 0 + + +######################################################################## +## Manager specific +######################################################################## + + +@dataclass +class RegisterMethodResponseData(DataClassJsonMixin): + success: bool + result: Dict[str, Any] + + +@dataclass +class UpdateSendArchiveData(DataClassJsonMixin): + worker_id: str + success: bool + + +@dataclass +class CommitData(DataClassJsonMixin): + graph: "Graph" + mapping: Dict[str, List[str]] + context: Literal["multiprocessing", "threading"] = "multiprocessing" + 
send_packages: Optional[List[Dict[str, Any]]] = None + + +######################################################################## +## Worker specific +######################################################################## + + +@dataclass +class ConnectData(DataClassJsonMixin): + method: Literal["ip", "zeroconf"] + host: Optional[str] = None + port: Optional[int] = None + + +@dataclass +class GatherData(DataClassJsonMixin): + node_id: str + output: Union[DataChunk, List[int]] + + +@dataclass +class ResultsData(DataClassJsonMixin): + node_id: str + success: bool + output: Any + + +@dataclass +class ServerMessage(DataClassJsonMixin): + signal: enum.Enum + data: Dict[str, Any] = field(default_factory=dict) + client_id: Optional[str] = None + + +######################################################################## +## Node specific +######################################################################## + + +@dataclass +class PreSetupData(DataClassJsonMixin): + state: "NodeState" + logger: logging.Logger + + +@dataclass +class RegisteredMethod(DataClassJsonMixin): + name: str + style: Literal["concurrent", "blocking", "reset"] = "concurrent" + params: Dict[str, str] = field(default_factory=dict) + + +@dataclass +class RegisteredMethodData(DataClassJsonMixin): + node_id: str + method_name: str + params: Dict[str, Any] = field(default_factory=dict) diff --git a/chimerapy/engine/eventbus/__init__.py b/chimerapy/engine/eventbus/__init__.py deleted file mode 100644 index 749c9606..00000000 --- a/chimerapy/engine/eventbus/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .eventbus import Event, EventBus, TypedObserver -from .observables import ObservableDict -from .wrapper import configure, evented, make_evented - -__all__ = [ - "EventBus", - "Event", - "TypedObserver", - "make_evented", - "evented", - "configure", - "ObservableDict", -] diff --git a/chimerapy/engine/eventbus/eventbus.py b/chimerapy/engine/eventbus/eventbus.py deleted file mode 100644 index 
a85264e3..00000000 --- a/chimerapy/engine/eventbus/eventbus.py +++ /dev/null @@ -1,188 +0,0 @@ -import asyncio -import uuid -from collections import deque -from concurrent.futures import Future -from dataclasses import dataclass, field -from datetime import datetime -from typing import ( - Any, - Awaitable, - Callable, - Dict, - Generic, - Literal, - Optional, - Type, - TypeVar, -) - -from aioreactive import AsyncObservable, AsyncObserver, AsyncSubject - -from .. import _logger -from ..networking.async_loop_thread import AsyncLoopThread -from ..utils import future_wrapper - -T = TypeVar("T") - -logger = _logger.getLogger("chimerapy-engine") - - -@dataclass -class Event: - type: str - data: Optional[Any] = None - id: str = field(default_factory=lambda: str(uuid.uuid4())) - timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat()) - - -class EventBus(AsyncObservable): - def __init__(self, thread: Optional[AsyncLoopThread] = None): - self.stream = AsyncSubject() - self._event_counts: int = 0 - self._sub_counts: int = 0 - self.thread = thread - - # State information - self.awaitable_events: Dict[str, asyncio.Event] = {} - self.subscription_map: Dict[AsyncObserver, Any] = {} - - #################################################################### - ## Getters and Setters - #################################################################### - - def set_thread(self, thread: AsyncLoopThread): - self.thread = thread - - #################################################################### - ## Async - #################################################################### - - async def asend(self, event: Event): - # logger.debug(f"EventBus: Sending event: {event}") - self._event_counts += 1 - await self.stream.asend(event) - - if event.type in self.awaitable_events: - self.latest_event = event - self.awaitable_events[event.type].set() - del self.awaitable_events[event.type] - - async def asubscribe(self, observer: AsyncObserver): - self._sub_counts += 1 - 
subscription = await self.stream.subscribe_async(observer) - self.subscription_map[observer] = subscription - - async def aunsubscribe(self, observer: AsyncObserver): - if observer not in self.subscription_map: - raise RuntimeError( - "Trying to unsubscribe an Observer that is not subscribed" - ) - - self._sub_counts -= 1 - subscription = self.subscription_map[observer] - await subscription.dispose_async() - - async def await_event(self, event_type: str) -> Event: - if event_type not in self.awaitable_events: - self.awaitable_events[event_type] = asyncio.Event() - - event_trigger = self.awaitable_events[event_type] - await event_trigger.wait() - return self.latest_event - - #################################################################### - ## Sync - #################################################################### - - def send( - self, event: Event, loop: Optional[asyncio.AbstractEventLoop] = None - ) -> Future: - if isinstance(self.thread, AsyncLoopThread): - return self.thread.exec(self.asend(event)) - else: - if loop is None: - loop = asyncio.get_event_loop() - wrapper, future = future_wrapper(self.asend(event)) - loop.create_task(wrapper) - return future - - def subscribe(self, observer: AsyncObserver) -> Future: - if isinstance(self.thread, AsyncLoopThread): - return self.thread.exec(self.asubscribe(observer)) - else: - loop = asyncio.get_event_loop() - wrapper, future = future_wrapper(self.asubscribe(observer)) - loop.create_task(wrapper) - return future - - -class TypedObserver(AsyncObserver, Generic[T]): - def __init__( - self, - event_type: str, - event_data_cls: Optional[Type[T]] = None, - on_asend: Optional[Callable] = None, - on_athrow: Optional[Callable] = None, - on_aclose: Optional[Callable] = None, - handle_event: Literal["pass", "unpack", "drop"] = "pass", - ): - - # Containers - self.event_type = event_type - self.event_data_cls = event_data_cls - self.handle_event = handle_event - self.received: deque[str] = deque(maxlen=10) - - # 
Callables - self._on_asend = on_asend - self._on_athrow = on_athrow - self._on_aclose = on_aclose - - def __str__(self) -> str: - string = f"" - return string - - def bind_asend(self, func: Callable[[Event], Awaitable[None]]): - self._on_asend = func - - def bind_athrow(self, func: Callable[[Exception], Awaitable[None]]): - self._on_athrow = func - - def bind_aclose(self, func: Callable[[], Awaitable[None]]): - self._on_aclose = func - - async def exec_callable(self, func: Callable, *arg, **kwargs): - if asyncio.iscoroutinefunction(func): - await func(*arg, **kwargs) - else: - func(*arg, **kwargs) - - async def asend(self, event: Event): - - if self.event_data_cls is None: - is_match = event.type == self.event_type - else: - is_match = ( - isinstance(event.data, self.event_data_cls) - and event.type == self.event_type - ) - - if is_match: - # logger.debug(f"{self}: asend!") - self.received.append(event.id) - if self._on_asend: - if self.handle_event == "pass": - await self.exec_callable(self._on_asend, event) - elif self.handle_event == "unpack": - await self.exec_callable(self._on_asend, **event.data.__dict__) - elif self.handle_event == "drop": - await self.exec_callable(self._on_asend) - - async def athrow(self, ex: Exception): - if self._on_athrow: - await self.exec_callable(self._on_athrow, ex) - - async def aclose(self): - if self._on_aclose: - await self.exec_callable(self._on_aclose) diff --git a/chimerapy/engine/eventbus/observables.py b/chimerapy/engine/eventbus/observables.py deleted file mode 100644 index 1dc9d556..00000000 --- a/chimerapy/engine/eventbus/observables.py +++ /dev/null @@ -1,65 +0,0 @@ -class ObservableDict(dict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.callback = None # initialize the callback - - def __setitem__(self, key, value): - super().__setitem__(key, value) - if self.callback: - self.callback(key, value) - - def __delitem__(self, key): - super().__delitem__(key) - if self.callback: - 
self.callback(key, None) - - def set_callback(self, callback): - self.callback = callback - - -class ObservableList(list): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.callback = None # initialize the callback - - def __setitem__(self, index, value): - super().__setitem__(index, value) - if self.callback: - self.callback(index, value) - - def __delitem__(self, index): - super().__delitem__(index) - if self.callback: - self.callback(index, None) - - def append(self, value): - super().append(value) - if self.callback: - self.callback(len(self) - 1, value) - - def extend(self, iterable): - start_len = len(self) - super().extend(iterable) - if self.callback: - for i, item in enumerate(iterable, start=start_len): - self.callback(i, item) - - def insert(self, index, value): - super().insert(index, value) - if self.callback: - self.callback(index, value) - - def remove(self, value): - index = self.index(value) - super().remove(value) - if self.callback: - self.callback(index, None) - - def pop(self, index=-1): - value = super().pop(index) - if self.callback: - self.callback(index, None) - return value - - def set_callback(self, callback): - self.callback = callback diff --git a/chimerapy/engine/eventbus/wrapper.py b/chimerapy/engine/eventbus/wrapper.py deleted file mode 100644 index 37232362..00000000 --- a/chimerapy/engine/eventbus/wrapper.py +++ /dev/null @@ -1,132 +0,0 @@ -from dataclasses import dataclass, fields, is_dataclass -from typing import Any, Optional, TypeVar - -import dataclasses_json - -from .eventbus import Event, EventBus -from .observables import ObservableDict, ObservableList - -T = TypeVar("T") - -# Global variables -global_event_bus: Optional["EventBus"] = None - - -@dataclass -class DataClassEvent: - dataclass: Any - - -def configure(event_bus: EventBus): - global global_event_bus - global_event_bus = event_bus - - -def evented(cls): - original_init = cls.__init__ - - def new_init(self, *args, **kwargs): - global 
global_event_bus - - self.event_bus = None - - if isinstance(global_event_bus, EventBus): - self.event_bus = global_event_bus - - original_init(self, *args, **kwargs) - - def make_property(name: str) -> Any: - def getter(self): - return self.__dict__[f"_{name}"] - - def setter(self, value): - self.__dict__[f"_{name}"] = value - if self.event_bus: - event_name = f"{cls.__name__}.changed" - event_data = DataClassEvent(self) - self.event_bus.send(Event(event_name, event_data)) - - return property(getter, setter) - - cls.__init__ = new_init - - for f in fields(cls): - if f.name != "event_bus": - setattr(cls, f.name, make_property(f.name)) - - setattr( - cls, - "event_bus", - dataclasses_json.config( - field_name="event_bus", encoder=lambda x: None, decoder=lambda x: None - ), - ) - - return cls - - -def make_evented( - instance: T, - event_bus: "EventBus", - event_name: Optional[str] = None, - object: Optional[Any] = None, -) -> T: - setattr(instance, "event_bus", event_bus) - instance.__evented_values = {} # type: ignore[attr-defined] - - # Name of the event - if not event_name: - event_name = f"{instance.__class__.__name__}.changed" - - # Dynamically create a new class with the same name as the instance's class - new_class_name = instance.__class__.__name__ - NewClass = type(new_class_name, (instance.__class__,), {}) - - def make_property(name: str): - def getter(self): - return self.__evented_values.get(name) - - def setter(self, value): - self.__evented_values[name] = value - if object: - event_data = DataClassEvent(object) - else: - event_data = DataClassEvent(self) - - event_bus.send(Event(event_name, event_data)) - - return property(getter, setter) - - def callback(key, value): - if object: - event_data = DataClassEvent(object) - else: - event_data = DataClassEvent(instance) - - event_bus.send(Event(event_name, event_data)) - - for f in fields(instance.__class__): - if f.name != "event_bus": - attr_value = getattr(instance, f.name) - - # Check if other dataclass 
- if is_dataclass(attr_value): - attr_value = make_evented(attr_value, event_bus, event_name, instance) - - # If the attribute is a dictionary, replace it with an ObservableDict - elif isinstance(attr_value, dict): - attr_value = ObservableDict(attr_value) - attr_value.set_callback(callback) - - # Handle list - elif isinstance(attr_value, list): - attr_value = ObservableList(attr_value) - attr_value.set_callback(callback) - - instance.__evented_values[f.name] = attr_value # type: ignore[attr-defined] - setattr(NewClass, f.name, make_property(f.name)) - - # Change the class of the instance - instance.__class__ = NewClass - - return instance diff --git a/chimerapy/engine/manager/distributed_logging_service.py b/chimerapy/engine/manager/distributed_logging_service.py index b7229cd4..6940c955 100644 --- a/chimerapy/engine/manager/distributed_logging_service.py +++ b/chimerapy/engine/manager/distributed_logging_service.py @@ -1,13 +1,13 @@ from typing import Dict, Optional +from aiodistbus import registry + from chimerapy.engine import _logger, config -from ..eventbus import EventBus, TypedObserver from ..logger.distributed_logs_sink import DistributedLogsMultiplexedFileSink from ..service import Service -from ..states import ManagerState +from ..states import ManagerState, WorkerState from ..utils import megabytes_to_bytes -from .events import DeregisterEntityEvent, RegisterEntityEvent logger = _logger.getLogger("chimerapy-engine") @@ -17,7 +17,6 @@ def __init__( self, name: str, publish_logs_via_zmq: bool, - eventbus: EventBus, state: ManagerState, **kwargs, ): @@ -26,37 +25,13 @@ def __init__( # Save parameters self.name = name self.logs_sink: Optional[DistributedLogsMultiplexedFileSink] = None - self.eventbus = eventbus self.state = state if publish_logs_via_zmq: handler_config = _logger.ZMQLogHandlerConfig.from_dict(kwargs) _logger.add_zmq_handler(logger, handler_config) - async def async_init(self): - - # Specify observers - self.observers: Dict[str, 
TypedObserver] = { - "start": TypedObserver("start", on_asend=self.start, handle_event="drop"), - "entity_register": TypedObserver( - "entity_register", - on_asend=self.register_entity, - event_data_cls=RegisterEntityEvent, - handle_event="unpack", - ), - "entity_deregister": TypedObserver( - "entity_deregister", - on_asend=self.deregister_entity, - event_data_cls=DeregisterEntityEvent, - handle_event="unpack", - ), - "shutdown": TypedObserver( - "shutdown", on_asend=self.shutdown, handle_event="drop" - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - + @registry.on("start", namespace=f"{__name__}.DistributedLoggingService") def start(self): if config.get("manager.logs-sink.enabled"): @@ -65,6 +40,7 @@ def start(self): else: self.logs_sink = None + @registry.on("shutdown", namespace=f"{__name__}.DistributedLoggingService") def shutdown(self): if self.logs_sink: @@ -74,15 +50,22 @@ def shutdown(self): ## Helper Function ##################################################################################### - def register_entity(self, worker_name: str, worker_id: str): + @registry.on( + "entity_register", + WorkerState, + namespace=f"{__name__}.DistributedLoggingService", + ) + def register_entity(self, state: WorkerState): # logger.debug(f"{self}: registereing entity: {worker_name}, {worker_id}") + id, name = state.id, state.name if self.logs_sink is not None: - self._register_worker_to_logs_sink( - worker_name=worker_name, worker_id=worker_id - ) + self._register_worker_to_logs_sink(worker_name=name, worker_id=id) + @registry.on( + "entity_deregister", str, namespace=f"{__name__}.DistributedLoggingService" + ) def deregister_entity(self, worker_id: str): if self.logs_sink is not None: diff --git a/chimerapy/engine/manager/events.py b/chimerapy/engine/manager/events.py deleted file mode 100644 index e85bf8b2..00000000 --- a/chimerapy/engine/manager/events.py +++ /dev/null @@ -1,40 +0,0 @@ -from dataclasses import dataclass - -from 
..states import WorkerState - - -@dataclass -class StartEvent: - ... - - -@dataclass -class UpdateSendArchiveEvent: # update_send_archive - worker_id: str - success: bool - - -@dataclass -class WorkerRegisterEvent: # worker_register - worker_state: WorkerState - - -@dataclass -class WorkerDeregisterEvent: # worker_deregister - worker_state: WorkerState - - -@dataclass -class RegisterEntityEvent: # entity_register - worker_name: str - worker_id: str - - -@dataclass -class DeregisterEntityEvent: # entity_deregister - worker_id: str - - -@dataclass -class MoveTransferredFilesEvent: # move_transferred_files - worker_state: WorkerState diff --git a/chimerapy/engine/manager/http_server_service.py b/chimerapy/engine/manager/http_server_service.py index 1d1557bb..d6e74d83 100644 --- a/chimerapy/engine/manager/http_server_service.py +++ b/chimerapy/engine/manager/http_server_service.py @@ -1,23 +1,18 @@ import traceback from concurrent.futures import Future -from typing import Dict, List +from typing import List +from aiodistbus import registry from aiohttp import web from chimerapy.engine import _logger, config -from ..eventbus import Event, EventBus, TypedObserver +from ..data_protocols import UpdateSendArchiveData from ..networking import Server from ..networking.enums import MANAGER_MESSAGE from ..service import Service from ..states import ManagerState, WorkerState from ..utils import update_dataclass -from .events import ( - MoveTransferredFilesEvent, - UpdateSendArchiveEvent, - WorkerDeregisterEvent, - WorkerRegisterEvent, -) logger = _logger.getLogger("chimerapy-engine") @@ -28,7 +23,6 @@ def __init__( name: str, port: int, enable_api: bool, - eventbus: EventBus, state: ManagerState, ): super().__init__(name=name) @@ -38,7 +32,6 @@ def __init__( self._ip = "172.0.0.1" self._port = port self._enable_api = enable_api - self.eventbus = eventbus self.state = state # Future Container @@ -58,29 +51,6 @@ def __init__( ], ) - async def async_init(self): - - # Specify 
observers - self.observers: Dict[str, TypedObserver] = { - "start": TypedObserver("start", on_asend=self.start, handle_event="drop"), - "ManagerState.changed": TypedObserver( - "ManagerState.changed", - on_asend=self._broadcast_network_status_update, - handle_event="drop", - ), - "shutdown": TypedObserver( - "shutdown", on_asend=self.shutdown, handle_event="drop" - ), - "move_transferred_files": TypedObserver( - "move_transferred_files", - MoveTransferredFilesEvent, - on_asend=self.move_transferred_files, - handle_event="unpack", - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - @property def ip(self) -> str: return self._ip @@ -93,6 +63,7 @@ def port(self) -> int: def url(self) -> str: return f"http://{self._ip}:{self._port}" + @registry.on("start", namespace=f"{__name__}.HttpServerService") async def start(self): # Runn the Server @@ -104,8 +75,9 @@ async def start(self): self.state.port = self.port # After updatign the information, then run it! - await self.eventbus.asend(Event("after_server_startup")) + await self.entrypoint.emit("after_server_startup") + @registry.on("shutdown", namespace=f"{__name__}.HttpServerService") async def shutdown(self) -> bool: # Finish any other tasks @@ -126,6 +98,9 @@ def _future_flush(self): except Exception: logger.error(traceback.format_exc()) + @registry.on( + "move_transferred_files", WorkerState, namespace=f"{__name__}.HttpServerService" + ) async def move_transferred_files(self, worker_state: WorkerState) -> bool: return await self._server.move_transferred_files( self.state.logdir, owner=worker_state.name, owner_id=worker_state.id @@ -147,9 +122,7 @@ async def _register_worker_route(self, request: web.Request): worker_state = WorkerState.from_dict(msg) # Register worker - await self.eventbus.asend( - Event("worker_register", WorkerRegisterEvent(worker_state)) - ) + await self.entrypoint.emit("worker_register", worker_state) response = { "logs_push_info": { @@ -168,9 +141,7 @@ async def 
_deregister_worker_route(self, request: web.Request): worker_state = WorkerState.from_dict(msg) # Deregister worker - await self.eventbus.asend( - Event("worker_deregister", WorkerDeregisterEvent(worker_state)) - ) + await self.entrypoint.emit("worker_deregister", worker_state) return web.HTTPOk() @@ -188,14 +159,15 @@ async def _update_nodes_status(self, request: web.Request): async def _update_send_archive(self, request: web.Request): msg = await request.json() - event_data = UpdateSendArchiveEvent(**msg) - await self.eventbus.asend(Event("update_send_archive", event_data)) + event_data = UpdateSendArchiveData(**msg) + await self.entrypoint.emit("update_send_archive", event_data) return web.HTTPOk() ##################################################################################### ## Front-End API ##################################################################################### + @registry.on("ManagerState.changed", namespace=f"{__name__}.HttpServerService") async def _broadcast_network_status_update(self): if not self._enable_api: diff --git a/chimerapy/engine/manager/manager.py b/chimerapy/engine/manager/manager.py index 2b9b6552..b05150ea 100644 --- a/chimerapy/engine/manager/manager.py +++ b/chimerapy/engine/manager/manager.py @@ -6,18 +6,19 @@ from typing import Any, Coroutine, Dict, List, Literal, Optional, Union import asyncio_atexit +from aiodistbus import EntryPoint, EventBus, make_evented, registry # Internal Imports from chimerapy.engine import _logger, config -from chimerapy.engine.graph import Graph -from chimerapy.engine.states import ManagerState, WorkerState -# Eventbus -from ..eventbus import Event, EventBus, make_evented +from ..data_protocols import CommitData, RegisteredMethodData +from ..graph import Graph from ..networking.async_loop_thread import AsyncLoopThread -from .distributed_logging_service import DistributedLoggingService +from ..service import Service +from ..states import ManagerState, WorkerState # Services +from 
.distributed_logging_service import DistributedLoggingService from .http_server_service import HttpServerService from .session_record_service import SessionRecordService from .worker_handler_service import WorkerHandlerService @@ -27,13 +28,6 @@ class Manager: - - http_server: HttpServerService - worker_handler: WorkerHandlerService - zeroconf_service: ZeroconfService - session_record: SessionRecordService - distributed_logging: DistributedLoggingService - def __init__( self, logdir: Union[pathlib.Path, str], @@ -68,6 +62,7 @@ def __init__( # Creating a container for task futures self.task_futures: List[Future] = [] + self.services: List[Service] = [] # Create log directory to store data self.timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") @@ -81,45 +76,44 @@ def __init__( async def aserve(self) -> bool: # Create eventbus - self.eventbus = EventBus() - self.state = make_evented(self.state, event_bus=self.eventbus) + self.bus = EventBus() + self.entrypoint = EntryPoint() + await self.entrypoint.connect(self.bus) + self.state = make_evented(self.state, bus=self.bus) # Create the services - self.http_server = HttpServerService( - name="http_server", - port=self.state.port, - enable_api=self.enable_api, - eventbus=self.eventbus, - state=self.state, - ) - self.worker_handler = WorkerHandlerService( - name="worker_handler", eventbus=self.eventbus, state=self.state + self.services.append( + HttpServerService( + name="http_server", + port=self.state.port, + enable_api=self.enable_api, + state=self.state, + ) ) - self.zeroconf_service = ZeroconfService( - name="zeroconf", eventbus=self.eventbus, state=self.state + self.services.append( + WorkerHandlerService(name="worker_handler", state=self.state) ) - self.session_record = SessionRecordService( - name="session_record", - eventbus=self.eventbus, - state=self.state, + self.services.append(ZeroconfService(name="zeroconf", state=self.state)) + self.services.append( + SessionRecordService( + name="session_record", + 
state=self.state, + ) ) - self.distributed_logging = DistributedLoggingService( - name="distributed_logging", - publish_logs_via_zmq=self.publish_logs_via_zmq, - eventbus=self.eventbus, - state=self.state, - # **self.kwargs, + self.services.append( + DistributedLoggingService( + name="distributed_logging", + publish_logs_via_zmq=self.publish_logs_via_zmq, + state=self.state, + # **self.kwargs, + ) ) - # Initialize services - await self.http_server.async_init() - await self.worker_handler.async_init() - await self.zeroconf_service.async_init() - await self.session_record.async_init() - await self.distributed_logging.async_init() + for service in self.services: + await service.attach(self.bus) # Start all services - await self.eventbus.asend(Event("start")) + await self.entrypoint.emit("start") # Logging logger.info(f"ChimeraPy: Manager running at {self.host}:{self.port}") @@ -176,97 +170,29 @@ def _exec_coro(self, coro: Coroutine) -> Future: return future - #################################################################### - ## Async Networking - #################################################################### + async def _register_graph(self, graph: Graph): + await self.entrypoint.emit("register_graph", graph) - async def _async_request_node_creation( - self, - worker_id: str, - node_id: str, - context: Literal["multiprocessing", "threading"] = "multiprocessing", - ) -> bool: - return await self.worker_handler._request_node_creation( - worker_id, node_id, context=context - ) - - async def _async_request_node_destruction( - self, worker_id: str, node_id: str - ) -> bool: - return await self.worker_handler._request_node_destruction(worker_id, node_id) - - async def _async_request_node_pub_table(self, worker_id: str) -> bool: - return await self.worker_handler._request_node_pub_table(worker_id) - - async def _async_request_connection_creation(self, worker_id: str) -> bool: - return await self.worker_handler._request_connection_creation(worker_id) - - async def 
_async_broadcast_request( - self, - htype: Literal["get", "post"], - route: str, - data: Any = {}, - timeout: Optional[Union[int, float]] = config.get( - "manager.timeout.info-request" - ), - report_exceptions: bool = True, - ) -> bool: - return await self.worker_handler._broadcast_request( - htype, route, data, timeout, report_exceptions - ) + async def _deregister_graph(self): + await self.entrypoint.emit("deregister_graph") #################################################################### - ## Sync Networking + ## EventListeners #################################################################### - def _register_graph(self, graph: Graph): - self.worker_handler._register_graph(graph) - - def _deregister_graph(self): - self.worker_handler._deregister_graph() - - def _request_node_creation( - self, - worker_id: str, - node_id: str, - context: Literal["multiprocessing", "threading"] = "multiprocessing", - ) -> Future[bool]: - return self._exec_coro( - self._async_request_node_creation(worker_id, node_id, context=context) - ) - - def _request_node_destruction(self, worker_id: str, node_id: str) -> Future[bool]: - return self._exec_coro(self._async_request_node_destruction(worker_id, node_id)) - - def _request_node_pub_table(self, worker_id: str) -> Future[bool]: - return self._exec_coro(self._async_request_node_pub_table(worker_id)) - - def _request_connection_creation(self, worker_id: str) -> Future[bool]: - return self._exec_coro(self._async_request_connection_creation(worker_id)) - - def _broadcast_request( - self, - htype: Literal["get", "post"], - route: str, - data: Any = {}, - timeout: Union[int, float] = config.get("manager.timeout.info-request"), - ) -> Future[bool]: - return self._exec_coro( - self._async_broadcast_request(htype, route, data, timeout) - ) + @registry.on("registered_method_rep", namespace=f"{__name__}.Manager") + async def registered_method_rep(self): + ... 
#################################################################### ## Front-facing ASync API #################################################################### - async def async_zeroconf(self, enable: bool = True) -> bool: - if enable: - return await self.zeroconf_service.enable() - else: - return await self.zeroconf_service.disable() + async def async_zeroconf(self, enable: bool = True): + await self.entrypoint.emit("zeroconf", enable) - async def async_diagnostics(self, enable: bool = True) -> bool: - return await self.worker_handler.diagnostics(enable) + async def async_diagnostics(self, enable: bool = True): + await self.entrypoint.emit("diagnostics", enable) async def async_commit( self, @@ -274,7 +200,7 @@ async def async_commit( mapping: Dict[str, List[str]], context: Literal["multiprocessing", "threading"] = "multiprocessing", send_packages: Optional[List[Dict[str, Any]]] = None, - ) -> bool: + ): """Committing ``Graph`` to the cluster. Committing refers to how the graph itself (with its nodes and edges) @@ -306,38 +232,39 @@ async def async_commit( via dictionary with the following key-value pairs: \ name:``str`` and path:``pathlit.Path``. 
- Returns: - bool: Success in cluster's setup - """ - return await self.worker_handler.commit( - graph, mapping, context=context, send_packages=send_packages - ) + commit_data = CommitData(graph, mapping, context, send_packages) + await self.entrypoint.emit("commit", commit_data) async def async_gather(self) -> Dict: - return await self.worker_handler.gather() + # TODO + await self.entrypoint.emit("gather") + return {} - async def async_start(self) -> bool: - return await self.worker_handler.start_workers() + async def async_start(self): + await self.entrypoint.emit("start") - async def async_record(self) -> bool: - return await self.worker_handler.record() + async def async_record(self): + await self.entrypoint.emit("record") async def async_request_registered_method( self, node_id: str, method_name: str, params: Dict[str, Any] = {} ) -> Dict[str, Any]: - return await self.worker_handler.request_registered_method( - node_id, method_name, params + reg_method_data = RegisteredMethodData( + node_id=node_id, method_name=method_name, params=params ) + await self.entrypoint.emit("request_registered_method", reg_method_data) + # TODO + return {} - async def async_stop(self) -> bool: - return await self.worker_handler.stop() + async def async_stop(self): + await self.entrypoint.emit("stop") - async def async_collect(self) -> bool: - return await self.worker_handler.collect() + async def async_collect(self): + await self.entrypoint.emit("collect") async def async_reset(self, keep_workers: bool = True): - return await self.worker_handler.reset(keep_workers) + await self.entrypoint.emit("reset", keep_workers) async def async_shutdown(self) -> bool: @@ -346,7 +273,7 @@ async def async_shutdown(self) -> bool: # logger.debug(f"{self}: requested to shutdown twice, skipping.") return True - await self.eventbus.asend(Event("shutdown")) + await self.entrypoint.emit("shutdown") self.has_shutdown = True return True @@ -409,19 +336,6 @@ def commit_graph( ) ) - def step(self) -> 
Future[bool]: - """Cluster step execution for offline operation. - - The ``step`` function is for careful but slow operation of the - cluster. For online execution, ``start`` and ``stop`` are the - methods to be used. - - Returns: - Future[bool]: Future of the success of step function broadcasting - - """ - return self._exec_coro(self._async_broadcast_request("post", "/nodes/step")) - def gather(self) -> Future[Dict]: return self._exec_coro(self.async_gather()) diff --git a/chimerapy/engine/manager/session_record_service.py b/chimerapy/engine/manager/session_record_service.py index 7c596e93..15c542e2 100644 --- a/chimerapy/engine/manager/session_record_service.py +++ b/chimerapy/engine/manager/session_record_service.py @@ -2,7 +2,8 @@ import json from typing import Dict, Optional -from ..eventbus import EventBus, TypedObserver +from aiodistbus import registry + from ..service import Service from ..states import ManagerState @@ -11,13 +12,11 @@ class SessionRecordService(Service): def __init__( self, name: str, - eventbus: EventBus, state: ManagerState, ): super().__init__(name=name) # Input parameters - self.eventbus = eventbus self.state = state # State information @@ -25,23 +24,7 @@ def __init__( self.stop_time: Optional[datetime.datetime] = None self.duration: float = 0 - async def async_init(self): - - # Specify observers - self.observers: Dict[str, TypedObserver] = { - "save_meta": TypedObserver( - "save_meta", on_asend=self._save_meta, handle_event="drop" - ), - "start_recording": TypedObserver( - "start_recording", on_asend=self.start_recording, handle_event="drop" - ), - "stop_recording": TypedObserver( - "stop_recording", on_asend=self.stop_recording, handle_event="drop" - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - + @registry.on("save_meta", namespace=f"{__name__}.SessionRecordService") def _save_meta(self): # Get the times, handle Optional if self.start_time: @@ -67,11 +50,13 @@ def _save_meta(self): with 
open(self.state.logdir / "meta.json", "w") as f: json.dump(meta, f, indent=2) + @registry.on("start_recording", namespace=f"{__name__}.SessionRecordService") def start_recording(self): # Mark the start time self.start_time = datetime.datetime.now() + @registry.on("stop_recording", namespace=f"{__name__}.SessionRecordService") def stop_recording(self): # Mark the stop time diff --git a/chimerapy/engine/manager/worker_handler_service.py b/chimerapy/engine/manager/worker_handler_service.py index e70382e2..2235c4bf 100644 --- a/chimerapy/engine/manager/worker_handler_service.py +++ b/chimerapy/engine/manager/worker_handler_service.py @@ -11,11 +11,17 @@ import aiohttp import dill import networkx as nx +from aiodistbus import make_evented, registry from chimerapy.engine import _logger, config -from ..data_protocols import NodePubTable -from ..eventbus import Event, EventBus, TypedObserver, make_evented +from ..data_protocols import ( + CommitData, + NodePubTable, + RegisteredMethodData, + RegisterMethodResponseData, + UpdateSendArchiveData, +) from ..exceptions import CommitGraphError from ..graph import Graph from ..networking import Client, DataChunk @@ -23,25 +29,16 @@ from ..service import Service from ..states import ManagerState, WorkerState from ..utils import async_waiting_for -from .events import ( - DeregisterEntityEvent, - MoveTransferredFilesEvent, - RegisterEntityEvent, - UpdateSendArchiveEvent, - WorkerDeregisterEvent, - WorkerRegisterEvent, -) logger = _logger.getLogger("chimerapy-engine") class WorkerHandlerService(Service): - def __init__(self, name: str, eventbus: EventBus, state: ManagerState): + def __init__(self, name: str, state: ManagerState): super().__init__(name=name) # Parameters self.name = name - self.eventbus = eventbus self.state = state # Containers @@ -56,35 +53,7 @@ def __init__(self, name: str, eventbus: EventBus, state: ManagerState): # Also create a tempfolder to store any miscellaneous files and folders self.tempfolder = 
pathlib.Path(tempfile.mkdtemp()) - async def async_init(self): - - # Specify observers - self.observers: Dict[str, TypedObserver] = { - "shutdown": TypedObserver( - "shutdown", on_asend=self.shutdown, handle_event="drop" - ), - "worker_register": TypedObserver( - "worker_register", - on_asend=self._register_worker, - event_data_cls=WorkerRegisterEvent, - handle_event="unpack", - ), - "worker_deregister": TypedObserver( - "worker_deregister", - on_asend=self._deregister_worker, - event_data_cls=WorkerDeregisterEvent, - handle_event="unpack", - ), - "update_send_archive": TypedObserver( - "update_send_archive", - on_asend=self.update_send_archive, - event_data_cls=UpdateSendArchiveEvent, - handle_event="unpack", - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - + @registry.on("shutdown", namespace=f"{__name__}.WorkerHandlerService") async def shutdown(self) -> bool: # If workers are connected, let's notify them that the cluster is @@ -120,10 +89,13 @@ def _get_worker_ip(self, worker_id: str) -> str: worker_info = self.state.workers[worker_id] return f"http://{worker_info.ip}:{worker_info.port}" + @registry.on( + "worker_register", WorkerState, namespace=f"{__name__}.WorkerHandlerService" + ) async def _register_worker(self, worker_state: WorkerState) -> bool: evented_worker_state = make_evented( - worker_state, event_bus=self.eventbus, event_name="ManagerState.changed" + worker_state, bus=self.entrypoint._bus, event_name="ManagerState.changed" ) self.state.workers[worker_state.id] = evented_worker_state logger.debug( @@ -132,23 +104,16 @@ async def _register_worker(self, worker_state: WorkerState) -> bool: ) # Register entity from logging - await self.eventbus.asend( - Event( - "entity_register", - RegisterEntityEvent( - worker_name=worker_state.name, worker_id=worker_state.id - ), - ) - ) - + await self.entrypoint.emit("entity_register", worker_state) return True + @registry.on( + "worker_deregister", WorkerState, 
namespace=f"{__name__}.WorkerHandlerService" + ) async def _deregister_worker(self, worker_state: WorkerState) -> bool: # Deregister entity from logging - await self.eventbus.asend( - Event("entity_deregister", DeregisterEntityEvent(worker_id=worker_state.id)) - ) + await self.entrypoint.emit("entity_deregister", worker_state.id) if worker_state.id in self.state.workers: state = self.state.workers[worker_state.id] @@ -162,9 +127,16 @@ async def _deregister_worker(self, worker_state: WorkerState) -> bool: return False - async def update_send_archive(self, worker_id: str, success: bool): + @registry.on( + "update_send_archive", + UpdateSendArchiveData, + namespace=f"{__name__}.WorkerHandlerService", + ) + async def update_send_archive(self, data: UpdateSendArchiveData): + worker_id, success = data.worker_id, data.success self.collected_workers[worker_id] = success + @registry.on("register_graph", Graph, namespace=f"{__name__}.WorkerHandlerService") def _register_graph(self, graph: Graph): """Verifying that a Graph is valid, that is a DAG. 
@@ -190,6 +162,7 @@ def _register_graph(self, graph: Graph): self.graph.G.nodes[node_id]["object"] ) + @registry.on("deregister_graph", namespace=f"{__name__}.WorkerHandlerService") def _deregister_graph(self): self.graph: Graph = Graph() self.graph_dumps: Dict[str, bytes] = {} @@ -628,16 +601,13 @@ async def _single_worker_collect(self, worker_id: str) -> bool: f"{self}: Collection failed, " "never updated on archival completion" ) + # else: + # logger.debug(f"{self}: Collection success, bool: ## Front-facing ASync API #################################################################### + @registry.on("diagnostics", bool, namespace=f"{__name__}.WorkerHandlerService") async def diagnostics(self, enable: bool = True) -> bool: return await self._broadcast_request( "post", "/nodes/diagnostics", data={"enable": enable} ) - async def commit( - self, - graph: Graph, - mapping: Dict[str, List[str]], - context: Literal["multiprocessing", "threading"] = "multiprocessing", - send_packages: Optional[List[Dict[str, Any]]] = None, - ) -> bool: + @registry.on("commit", CommitData, namespace=f"{__name__}.WorkerHandlerService") + async def commit(self, commit_data: CommitData): """Committing ``Graph`` to the cluster. Committing refers to how the graph itself (with its nodes and edges) @@ -691,31 +657,29 @@ async def commit( via dictionary with the following key-value pairs: \ name:``str`` and path:``pathlit.Path``. 
- Returns: - bool: Success in cluster's setup - """ # First, test that the graph and the mapping are valid - self._register_graph(graph) - self._map_graph(mapping) - await self.eventbus.asend(Event("save_meta")) + self._register_graph(commit_data.graph) + self._map_graph(commit_data.mapping) + await self.entrypoint.emit("save_meta") # Then send requested packages success = True - if send_packages: - success = await self._distribute_packages(send_packages) + if commit_data.send_packages: + success = await self._distribute_packages(commit_data.send_packages) # If package are sent correctly, try to create network # Start with the nodes and then the connections if ( success - and await self._create_p2p_network(context=context) + and await self._create_p2p_network(context=commit_data.context) and await self._setup_p2p_connections() ): return True return False + @registry.on("gather", namespace=f"{__name__}.WorkerHandlerService") async def gather(self) -> Dict: # Wail until all workers have responded with their node server data gather_data = {} @@ -742,69 +706,76 @@ async def gather(self) -> Dict: return gather_data - async def start_workers(self) -> bool: + @registry.on("start", namespace=f"{__name__}.WorkerHandlerService") + async def start_workers(self): # Tell the cluster to start success = await self._broadcast_request("post", "/nodes/start") # Updating meta just in case of failure - await self.eventbus.asend(Event("save_meta")) + await self.entrypoint.emit("save_meta") return success + @registry.on("record", namespace=f"{__name__}.WorkerHandlerService") async def record(self) -> bool: # Mark the start time - await self.eventbus.asend(Event("start_recording")) + await self.entrypoint.emit("start_recording") # Tell the cluster to start success = await self._broadcast_request("post", "/nodes/record") # Updating meta just in case of failure - await self.eventbus.asend(Event("save_meta")) + await self.entrypoint.emit("save_meta") return success + @registry.on( + 
"request_registered_method", + RegisteredMethodData, + namespace=f"{__name__}.WorkerHandlerService", + ) async def request_registered_method( - self, node_id: str, method_name: str, params: Dict[str, Any] = {} - ) -> Dict[str, Any]: + self, reg_method_data: RegisteredMethodData + ) -> RegisterMethodResponseData: # First, identify which worker has the node - worker_id = self._node_to_worker_lookup(node_id) + worker_id = self._node_to_worker_lookup(reg_method_data.node_id) if not isinstance(worker_id, str): - return {"success": False, "output": None} - - data = { - "node_id": str(node_id), - "method_name": str(method_name), - "params": dict(params), - } + logger.error(f"{self}: Registered Method for Worker {worker_id}: FAILED") + res_data = RegisterMethodResponseData(success=False, result={}) + return res_data async with self.http_client.post( f"{self._get_worker_ip(worker_id)}/nodes/registered_methods", - data=json.dumps(data), + data=reg_method_data.to_json(), ) as resp: if resp.ok: resp_data = await resp.json() - return resp_data + res_data = RegisterMethodResponseData.from_json(resp_data) + return res_data else: - logger.debug( + logger.error( f"{self}: Registered Method for Worker {worker_id}: FAILED" ) - return {"success": False, "output": None} + res_data = RegisterMethodResponseData(success=False, result={}) + return res_data + @registry.on("stop", namespace=f"{__name__}.WorkerHandlerService") async def stop(self) -> bool: # Mark the start time - await self.eventbus.asend(Event("stop_recording")) + await self.entrypoint.emit("stop_recording") # Tell the cluster to start success = await self._broadcast_request("post", "/nodes/stop") return success + @registry.on("collect", namespace=f"{__name__}.WorkerHandlerService") async def collect(self) -> bool: # Clear @@ -821,9 +792,10 @@ async def collect(self) -> bool: logger.error(traceback.format_exc()) return False - await self.eventbus.asend(Event("save_meta")) + await self.entrypoint.emit("save_meta") return 
all(results) + @registry.on("reset", bool, namespace=f"{__name__}.WorkerHandlerService") async def reset(self, keep_workers: bool = True): # Destroy Nodes safely diff --git a/chimerapy/engine/manager/zeroconf_service.py b/chimerapy/engine/manager/zeroconf_service.py index 7d003a41..da943d19 100644 --- a/chimerapy/engine/manager/zeroconf_service.py +++ b/chimerapy/engine/manager/zeroconf_service.py @@ -2,11 +2,11 @@ from datetime import datetime from typing import Dict, Optional +from aiodistbus import registry from zeroconf import ServiceInfo, Zeroconf from chimerapy.engine import _logger -from ..eventbus import EventBus, TypedObserver from ..service import Service from ..states import ManagerState @@ -17,32 +17,18 @@ class ZeroconfService(Service): enabled: bool - def __init__(self, name: str, eventbus: EventBus, state: ManagerState): + def __init__(self, name: str, state: ManagerState): super().__init__(name=name) # Save information self.name = name - self.eventbus = eventbus self.state = state # Creating zeroconf variables self.zeroconf: Optional[Zeroconf] = None self.enabled: bool = False - async def async_init(self): - - # Specify observers - self.observers: Dict[str, TypedObserver] = { - "after_server_startup": TypedObserver( - "after_server_startup", on_asend=self.start, handle_event="drop" - ), - "shutdown": TypedObserver( - "shutdown", on_asend=self.shutdown, handle_event="drop" - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - + @registry.on("after_server_startup", namespace=f"{__name__}.ZeroconfService") def start(self): # Create the zeroconf service name @@ -60,6 +46,14 @@ def start(self): }, ) + @registry.on("zeroconf", bool, namespace=f"{__name__}.ZeroconfService") + async def zeroconf_update(self, enable: bool): + if enable: + await self.enable() + else: + await self.disable() + + @registry.on("shutdown", namespace=f"{__name__}.ZeroconfService") async def shutdown(self): await self.disable() diff --git 
a/chimerapy/engine/node/events.py b/chimerapy/engine/node/events.py deleted file mode 100644 index 2f590012..00000000 --- a/chimerapy/engine/node/events.py +++ /dev/null @@ -1,43 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict - -from ..data_protocols import NodeDiagnostics, NodePubTable -from ..networking.client import Client -from ..networking.data_chunk import DataChunk - - -@dataclass -class EnableDiagnosticsEvent: # enable_diagnostics - enable: bool - - -@dataclass -class NewInBoundDataEvent: - data_chunks: Dict[str, DataChunk] - - -@dataclass -class NewOutBoundDataEvent: - data_chunk: DataChunk - - -@dataclass -class ProcessNodePubTableEvent: - node_pub_table: NodePubTable - - -@dataclass -class RegisteredMethodEvent: - method_name: str - params: Dict[str, Any] - client: Client - - -@dataclass -class GatherEvent: - client: Client - - -@dataclass -class DiagnosticsReportEvent: # diagnostics_report - diagnostics: NodeDiagnostics diff --git a/chimerapy/engine/node/fsm_service.py b/chimerapy/engine/node/fsm_service.py index b64dc0b6..781f5788 100644 --- a/chimerapy/engine/node/fsm_service.py +++ b/chimerapy/engine/node/fsm_service.py @@ -1,70 +1,47 @@ import logging -from typing import Dict -from ..eventbus import EventBus, TypedObserver +from aiodistbus import registry + from ..service import Service from ..states import NodeState class FSMService(Service): - def __init__( - self, name: str, state: NodeState, eventbus: EventBus, logger: logging.Logger - ): + def __init__(self, name: str, state: NodeState, logger: logging.Logger): super().__init__(name=name) # Save params self.state = state - self.eventbus = eventbus self.logger = logger - async def async_init(self): - - # Add observers - self.observers: Dict[str, TypedObserver] = { - "initialize": TypedObserver( - "initialize", on_asend=self.init, handle_event="drop" - ), - "setup": TypedObserver("setup", on_asend=self.setup, handle_event="drop"), - "start": TypedObserver("start", 
on_asend=self.start, handle_event="drop"), - "setup_connections": TypedObserver( - "setup_connections", - on_asend=self.setup_connections, - handle_event="drop", - ), - "record": TypedObserver( - "record", on_asend=self.record, handle_event="drop" - ), - "stop": TypedObserver("stop", on_asend=self.stop, handle_event="drop"), - "collect": TypedObserver( - "collect", on_asend=self.collect, handle_event="drop" - ), - "teardown": TypedObserver( - "teardown", on_asend=self.teardown, handle_event="drop" - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - + @registry.on("initialize", namespace=f"{__name__}.FSMService") async def init(self): self.state.fsm = "INITIALIZED" + @registry.on("setup", namespace=f"{__name__}.FSMService") async def setup(self): self.state.fsm = "READY" + @registry.on("setup_connections", namespace=f"{__name__}.FSMService") async def setup_connections(self): self.state.fsm = "CONNECTED" + @registry.on("start", namespace=f"{__name__}.FSMService") async def start(self): self.state.fsm = "PREVIEWING" + @registry.on("record", namespace=f"{__name__}.FSMService") async def record(self): self.state.fsm = "RECORDING" + @registry.on("stop", namespace=f"{__name__}.FSMService") async def stop(self): self.state.fsm = "STOPPED" + @registry.on("collect", namespace=f"{__name__}.FSMService") async def collect(self): self.state.fsm = "SAVED" + @registry.on("teardown", namespace=f"{__name__}.FSMService") async def teardown(self): self.state.fsm = "SHUTDOWN" diff --git a/chimerapy/engine/node/node.py b/chimerapy/engine/node/node.py index d0ef09d7..e128faea 100644 --- a/chimerapy/engine/node/node.py +++ b/chimerapy/engine/node/node.py @@ -6,22 +6,23 @@ import tempfile import uuid from asyncio import Task -from concurrent.futures import Future -from typing import Any, Coroutine, Dict, List, Literal, Optional, Tuple, Union +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Any, Dict, List, Literal, Optional, 
Union # Third-party Imports import multiprocess as mp import numpy as np import pandas as pd +from aiodistbus import EntryPoint, EventBus, make_evented # Internal Imports -from chimerapy.engine import _logger, config +from chimerapy.engine import _logger -from ..eventbus import Event, EventBus, make_evented +from ..data_protocols import PreSetupData from ..networking import DataChunk -from ..networking.async_loop_thread import AsyncLoopThread +from ..service import Service from ..states import NodeState -from ..utils import future_wrapper +from ..utils import run_coroutine_in_thread from .fsm_service import FSMService # Service Imports @@ -73,11 +74,15 @@ def __init__( self.debug_port = debug_port # State variables - self._thread: Optional[AsyncLoopThread] = None self._running: Union[bool, mp.Value] = True # type: ignore + self.executor: Optional[ThreadPoolExecutor] = None self.eventloop_future: Optional[Future] = None self.eventloop_task: Optional[Task] = None + self._save_tasks: List[Task] = [] self.task_futures: List[Future] = [] + self.services: List[Service] = [] + self.bus: Optional[EventBus] = None + self.entrypoint = EntryPoint() # Generic Node needs self.logger: logging.Logger = logging.getLogger("chimerapy-engine-node") @@ -85,11 +90,6 @@ def __init__( # Default values self.node_config = NodeConfig() - self.worker_comms: Optional[WorkerCommsService] = None - self.processor: Optional[ProcessorService] = None - self.recorder: Optional[RecordService] = None - self.poller: Optional[PollerService] = None - self.publisher: Optional[PublisherService] = None #################################################################### ## Properties @@ -138,10 +138,17 @@ def get_logger(self) -> logging.Logger: return logger # If worker, add zmq handler - if self.worker_comms or self.debug_port: + if "WorkerCommsService" in self.services and isinstance( + self.services["WorkerCommsService"], WorkerCommsService + ): + worker_comms = self.services["WorkerCommsService"] + 
else: + worker_comms = None + + if worker_comms or self.debug_port: - if self.worker_comms: - logging_port = self.worker_comms.worker_logging_port + if worker_comms: + logging_port = worker_comms.worker_logging_port elif self.debug_port: logging_port = self.debug_port else: @@ -157,10 +164,6 @@ def get_logger(self) -> logging.Logger: def add_worker_comms(self, worker_comms: WorkerCommsService): - # Store service - self.worker_comms = worker_comms - self.worker_comms - # Add the context information self.node_config = worker_comms.node_config @@ -168,19 +171,8 @@ def add_worker_comms(self, worker_comms: WorkerCommsService): self.state.logdir = worker_comms.worker_logdir / self.state.name os.makedirs(self.state.logdir, exist_ok=True) - def _exec_coro(self, coro: Coroutine) -> Tuple[Future, Optional[Task]]: - if isinstance(self._thread, AsyncLoopThread): - future = self._thread.exec(coro) - task = None - else: - loop = asyncio.get_event_loop() - wrapper, future = future_wrapper(coro) - task = loop.create_task(wrapper) - - # Saving the future for later use - self.task_futures.append(future) - - return future, task + # Store service + self.services.append(worker_comms) #################################################################### ## Saving Data Stream API @@ -188,24 +180,18 @@ def _exec_coro(self, coro: Coroutine) -> Tuple[Future, Optional[Task]]: def save_video(self, name: str, data: np.ndarray, fps: int): - if not self.recorder: - self.logger.warning( - f"{self}: cannot perform recording operation without RecorderService " - "initialization" - ) - return False - - if self.recorder.enabled: - timestamp = datetime.datetime.now() - video_entry = { - "uuid": uuid.uuid4(), - "name": name, - "data": data, - "dtype": "video", - "fps": fps, - "timestamp": timestamp, - } - self.recorder.submit(video_entry) + timestamp = datetime.datetime.now() + video_entry = { + "uuid": uuid.uuid4(), + "name": name, + "data": data, + "dtype": "video", + "fps": fps, + "timestamp": 
timestamp, + } + self._save_tasks.append( + self.loop.create_task(self.entrypoint.emit("record_entry", video_entry)) + ) def save_audio( self, name: str, data: np.ndarray, channels: int, format: int, rate: int @@ -230,26 +216,20 @@ def save_audio( It is the implementation's responsibility to properly format the data """ - if not self.recorder: - self.logger.warning( - f"{self}: cannot perform recording operation without RecorderService " - "initialization" - ) - return False - - if self.recorder.enabled: - audio_entry = { - "uuid": uuid.uuid4(), - "name": name, - "data": data, - "dtype": "audio", - "channels": channels, - "format": format, - "rate": rate, - "recorder_version": 1, - "timestamp": datetime.datetime.now(), - } - self.recorder.submit(audio_entry) + audio_entry = { + "uuid": uuid.uuid4(), + "name": name, + "data": data, + "dtype": "audio", + "channels": channels, + "format": format, + "rate": rate, + "recorder_version": 1, + "timestamp": datetime.datetime.now(), + } + self._save_tasks.append( + self.loop.create_task(self.entrypoint.emit("record_entry", audio_entry)) + ) def save_audio_v2( self, @@ -277,65 +257,48 @@ def save_audio_v2( nframes : int Number of frames. 
""" - if not self.recorder: - self.logger.warning( - f"{self}: cannot perform recording operation without RecorderService " - "initialization" - ) - return - - if self.recorder.enabled: - audio_entry = { - "uuid": uuid.uuid4(), - "name": name, - "data": data, - "dtype": "audio", - "channels": channels, - "sampwidth": sampwidth, - "framerate": framerate, - "nframes": nframes, - "recorder_version": 2, - "timestamp": datetime.datetime.now(), - } - self.recorder.submit(audio_entry) + + audio_entry = { + "uuid": uuid.uuid4(), + "name": name, + "data": data, + "dtype": "audio", + "channels": channels, + "sampwidth": sampwidth, + "framerate": framerate, + "nframes": nframes, + "recorder_version": 2, + "timestamp": datetime.datetime.now(), + } + self._save_tasks.append( + self.loop.create_task(self.entrypoint.emit("record_entry", audio_entry)) + ) def save_tabular( self, name: str, data: Union[pd.DataFrame, Dict[str, Any], pd.Series] ): - if not self.recorder: - self.logger.warning( - f"{self}: cannot perform recording operation without RecorderService " - "initialization" - ) - return False - - if self.recorder.enabled: - tabular_entry = { - "uuid": uuid.uuid4(), - "name": name, - "data": data, - "dtype": "tabular", - "timestamp": datetime.datetime.now(), - } - self.recorder.submit(tabular_entry) + tabular_entry = { + "uuid": uuid.uuid4(), + "name": name, + "data": data, + "dtype": "tabular", + "timestamp": datetime.datetime.now(), + } + self._save_tasks.append( + self.loop.create_task(self.entrypoint.emit("record_entry", tabular_entry)) + ) def save_image(self, name: str, data: np.ndarray): - if not self.recorder: - self.logger.warning( - f"{self}: cannot perform recording operation without RecorderService " - "initialization" - ) - return False - - if self.recorder.enabled: - image_entry = { - "uuid": uuid.uuid4(), - "name": name, - "data": data, - "dtype": "image", - "timestamp": datetime.datetime.now(), - } - self.recorder.submit(image_entry) + image_entry = { + 
"uuid": uuid.uuid4(), + "name": name, + "data": data, + "dtype": "image", + "timestamp": datetime.datetime.now(), + } + self._save_tasks.append( + self.loop.create_task(self.entrypoint.emit("record_entry", image_entry)) + ) def save_json(self, name: str, data: Dict[Any, Any]): """Record json data from the node to a JSON Lines file. @@ -353,23 +316,16 @@ def save_json(self, name: str, data: Dict[Any, Any]): The data is recorded in JSON Lines format, which is a sequence of JSON objects. The data dictionary provided must be JSON serializable. """ - - if not self.recorder: - self.logger.warning( - f"{self}: cannot perform recording operation without RecorderService " - "initialization" - ) - return False - - if self.recorder.enabled: - json_entry = { - "uuid": uuid.uuid4(), - "name": name, - "data": data, - "dtype": "json", - "timestamp": datetime.datetime.now(), - } - self.recorder.submit(json_entry) + json_entry = { + "uuid": uuid.uuid4(), + "name": name, + "data": data, + "dtype": "json", + "timestamp": datetime.datetime.now(), + } + self._save_tasks.append( + self.loop.create_task(self.entrypoint.emit("record_entry", json_entry)) + ) def save_text(self, name: str, data: str, suffix="txt"): """Record text data from the node to a text file. @@ -390,23 +346,17 @@ def save_text(self, name: str, data: str, suffix="txt"): It should be noted that new lines addition should be taken by the callee. 
""" - if not self.recorder: - self.logger.warning( - f"{self}: cannot perform recording operation without RecorderService " - "initialization" - ) - return False - - if self.recorder.enabled: - text_entry = { - "uuid": uuid.uuid4(), - "name": name, - "data": data, - "suffix": suffix, - "dtype": "text", - "timestamp": datetime.datetime.now(), - } - self.recorder.submit(text_entry) + text_entry = { + "uuid": uuid.uuid4(), + "name": name, + "data": data, + "suffix": suffix, + "dtype": "text", + "timestamp": datetime.datetime.now(), + } + self._save_tasks.append( + self.loop.create_task(self.entrypoint.emit("record_entry", text_entry)) + ) #################################################################### ## Back-End Lifecycle API @@ -414,27 +364,11 @@ def save_text(self, name: str, data: str, suffix="txt"): async def _setup(self): - # Adding state to the WorkerCommsService - if self.worker_comms: - self.worker_comms.in_node_config( - state=self.state, eventbus=self.eventbus, logger=self.logger - ) - if self.worker_comms.worker_config: - config.update_defaults(self.worker_comms.worker_config) - elif not self.state.logdir: - self.state.logdir = pathlib.Path(tempfile.mktemp()) - - # Create the directory - if self.state.logdir: - os.makedirs(self.state.logdir, exist_ok=True) - else: - raise RuntimeError(f"{self}: logdir {self.state.logdir} not set!") - # Make the state evented - self.state = make_evented(self.state, event_bus=self.eventbus) + self.state = make_evented(self.state, bus=self.bus) # Add the FSM service - self.fsm_service = FSMService("fsm", self.state, self.eventbus, self.logger) + self.services.append(FSMService("fsm", self.state, self.logger)) # Configure the processor's operational mode mode: Literal["main", "step"] = "step" # default @@ -460,72 +394,79 @@ async def _setup(self): in_bound_data = len(self.node_config.in_bound) != 0 # Create services - self.processor = ProcessorService( - name="processor", - setup_fn=self.setup, - main_fn=main_fn, - 
teardown_fn=self.teardown, - operation_mode=mode, - registered_methods=self.registered_methods, - registered_node_fns=registered_fns, - state=self.state, - eventbus=self.eventbus, - in_bound_data=in_bound_data, - logger=self.logger, + self.services.append( + ProcessorService( + name="processor", + setup_fn=self.setup, + main_fn=main_fn, + teardown_fn=self.teardown, + operation_mode=mode, + registered_methods=self.registered_methods, + registered_node_fns=registered_fns, + state=self.state, + in_bound_data=in_bound_data, + logger=self.logger, + ) ) - self.recorder = RecordService( - name="recorder", - state=self.state, - eventbus=self.eventbus, - logger=self.logger, + self.services.append( + RecordService( + name="recorder", + state=self.state, + logger=self.logger, + ) ) - self.profiler = ProfilerService( - name="profiler", - state=self.state, - eventbus=self.eventbus, - logger=self.logger, + self.services.append( + ProfilerService( + name="profiler", + state=self.state, + logger=self.logger, + ) ) # If in-bound, enable the poller service if self.node_config and self.node_config.in_bound: - self.poller = PollerService( - name="poller", - in_bound=self.node_config.in_bound, - in_bound_by_name=self.node_config.in_bound_by_name, - follow=self.node_config.follow, - state=self.state, - eventbus=self.eventbus, - logger=self.logger, + self.services.append( + PollerService( + name="poller", + in_bound=self.node_config.in_bound, + in_bound_by_name=self.node_config.in_bound_by_name, + follow=self.node_config.follow, + state=self.state, + logger=self.logger, + ) ) # If out_bound, enable the publisher service if self.node_config and self.node_config.out_bound: - self.publisher = PublisherService( - "publisher", - state=self.state, - eventbus=self.eventbus, - logger=self.logger, + self.services.append( + PublisherService( + "publisher", + state=self.state, + logger=self.logger, + ) ) # Initialize all services - if self.worker_comms: - await self.worker_comms.async_init() - if 
self.poller: - await self.poller.async_init() - if self.publisher: - await self.publisher.async_init() - - await self.processor.async_init() - await self.recorder.async_init() - await self.profiler.async_init() - await self.fsm_service.async_init() + for s in self.services: + await s.attach(self.bus) + + # Presetup + data = PreSetupData(self.state, self.logger) + await self.entrypoint.emit("pre_setup", data) + + # Create the directory + if not self.state.logdir: + self.state.logdir = pathlib.Path(tempfile.mktemp()) + os.makedirs(self.state.logdir, exist_ok=True) # Start all services - await self.eventbus.asend(Event("setup")) + await self.entrypoint.emit("setup") + # self.logger.debug(f"{self}: setup complete") async def _eventloop(self): # self.logger.debug(f"{self}: within event loop") await self._idle() # stop, running, and collecting + # self.logger.debug(f"{self}: after idle") await self._teardown() # self.logger.debug(f"{self}: exiting") return 1 @@ -535,7 +476,8 @@ async def _idle(self): await asyncio.sleep(0.2) async def _teardown(self): - await self.eventbus.asend(Event("teardown")) + await asyncio.gather(*self._save_tasks) + await self.entrypoint.emit("teardown") #################################################################### ## Front-facing Node Lifecycle API @@ -606,7 +548,7 @@ def teardown(self): def run( self, - eventbus: Optional[EventBus] = None, + bus: Optional[EventBus] = None, running: Optional[mp.Value] = None, # type: ignore ): """The actual method that is executed in the new process. 
@@ -625,33 +567,28 @@ def run( if type(running) != type(None): self._running = running - # Start an async loop - self._thread = AsyncLoopThread() - self._thread.start() - - if not eventbus: - self.eventbus = EventBus(thread=self._thread) - else: - self.eventbus = eventbus - self.eventbus.set_thread(self._thread) - # Have to run setup before letting the system continue - self.eventloop_future, _ = self._exec_coro(self.arun(self.eventbus)) - return 1 + run_func = lambda x: run_coroutine_in_thread(self.arun(x)) + self.executor = ThreadPoolExecutor(max_workers=1) + self.eventloop_future = self.executor.submit(run_func, bus) + return self.eventloop_future.result() - async def arun(self, eventbus: Optional[EventBus] = None): + async def arun(self, bus: Optional[EventBus] = None): self.logger = self.get_logger() self.logger.setLevel(self.logging_level) + self.loop = asyncio.get_event_loop() # Save parameters - if eventbus: - self.eventbus = eventbus + if bus: + self.bus = bus else: - self.eventbus = EventBus() + self.bus = EventBus() + + # Create an entrypoint + await self.entrypoint.connect(self.bus) await self._setup() - self.eventloop_task = asyncio.create_task(self._eventloop()) - return 1 + return await self._eventloop() def shutdown(self, timeout: Optional[Union[float, int]] = None): self.running = False @@ -663,4 +600,5 @@ def shutdown(self, timeout: Optional[Union[float, int]] = None): async def ashutdown(self): self.running = False - await self.eventloop_task + if self.eventloop_task: + await self.eventloop_task diff --git a/chimerapy/engine/node/poller_service.py b/chimerapy/engine/node/poller_service.py index 0373023d..39102773 100644 --- a/chimerapy/engine/node/poller_service.py +++ b/chimerapy/engine/node/poller_service.py @@ -2,14 +2,14 @@ import logging from typing import Dict, List, Optional +from aiodistbus import EntryPoint, EventBus, registry + from chimerapy.engine import _logger from ..data_protocols import NodePubEntry, NodePubTable -from ..eventbus 
import Event, EventBus, TypedObserver from ..networking import DataChunk, Subscriber from ..service import Service from ..states import NodeState -from .events import NewInBoundDataEvent, ProcessNodePubTableEvent class PollerService(Service): @@ -19,7 +19,6 @@ def __init__( in_bound: List[str], in_bound_by_name: List[str], state: NodeState, - eventbus: EventBus, follow: Optional[str] = None, logger: Optional[logging.Logger] = None, ): @@ -30,7 +29,6 @@ def __init__( self.in_bound_by_name: List[str] = in_bound_by_name self.follow: Optional[str] = follow self.state = state - self.eventbus = eventbus # Logging if logger: @@ -39,30 +37,15 @@ def __init__( self.logger = _logger.getLogger("chimerapy-engine") # Containers + self.emit_counter: int = 0 self.sub: Optional[Subscriber] = None self.in_bound_data: Dict[str, DataChunk] = {} - async def async_init(self): - - # Specify observers - self.observers: Dict[str, TypedObserver] = { - "teardown": TypedObserver( - "teardown", on_asend=self.teardown, handle_event="drop" - ), - "setup_connections": TypedObserver( - "setup_connections", - ProcessNodePubTableEvent, - on_asend=self.setup_connections, - handle_event="unpack", - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - #################################################################### ## Lifecycle Hooks #################################################################### + @registry.on("teardown", namespace=f"{__name__}.PollerService") async def teardown(self): # Shutting down subscriber @@ -73,6 +56,9 @@ async def teardown(self): ## Helper Methods #################################################################### + @registry.on( + "setup_connections", NodePubTable, namespace=f"{__name__}.PollerService" + ) async def setup_connections(self, node_pub_table: NodePubTable): # Create a subscriber @@ -118,6 +104,5 @@ async def update_data(self, datas: Dict[str, bytes]): # If update on the follow and all inputs available, then use the inputs 
if follow_event and len(self.in_bound_data) == len(self.in_bound): - await self.eventbus.asend( - Event("in_step", NewInBoundDataEvent(self.in_bound_data)) - ) + await self.entrypoint.emit("in_step", self.in_bound_data) + self.emit_counter += 1 diff --git a/chimerapy/engine/node/processor_service.py b/chimerapy/engine/node/processor_service.py index f161f746..e6fdc165 100644 --- a/chimerapy/engine/node/processor_service.py +++ b/chimerapy/engine/node/processor_service.py @@ -4,23 +4,22 @@ import threading import time import traceback +from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional +from aiodistbus import registry + from chimerapy.engine import _logger -from ..eventbus import Event, EventBus, TypedObserver +from ..data_protocols import ( + GatherData, + RegisteredMethod, + RegisteredMethodData, + ResultsData, +) from ..networking import DataChunk -from ..networking.client import Client -from ..networking.enums import NODE_MESSAGE from ..service import Service from ..states import NodeState -from .events import ( - GatherEvent, - NewInBoundDataEvent, - NewOutBoundDataEvent, - RegisteredMethodEvent, -) -from .registered_method import RegisteredMethod class ProcessorService(Service): @@ -28,7 +27,6 @@ def __init__( self, name: str, state: NodeState, - eventbus: EventBus, in_bound_data: bool, setup_fn: Optional[Callable] = None, main_fn: Optional[Callable] = None, @@ -42,7 +40,6 @@ def __init__( # Saving input parameters self.state = state - self.eventbus = eventbus self.setup_fn = setup_fn self.main_fn = main_fn self.teardown_fn = teardown_fn @@ -63,34 +60,13 @@ def __init__( self.running_task: Optional[asyncio.Task] = None self.tasks: List[asyncio.Task] = [] self.main_thread: Optional[threading.Thread] = None - - async def async_init(self): - - # Put observers - self.observers: Dict[str, TypedObserver] = { - "setup": TypedObserver("setup", on_asend=self.setup, handle_event="drop"), - 
"start": TypedObserver("start", on_asend=self.start, handle_event="drop"), - "stop": TypedObserver("stop", on_asend=self.stop, handle_event="drop"), - "registered_method": TypedObserver( - "registered_method", - RegisteredMethodEvent, - on_asend=self.execute_registered_method, - handle_event="unpack", - ), - "gather": TypedObserver( - "gather", GatherEvent, on_asend=self.gather, handle_event="unpack" - ), - "teardown": TypedObserver( - "teardown", on_asend=self.teardown, handle_event="drop" - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) + self.executor: ThreadPoolExecutor = ThreadPoolExecutor() #################################################################### ## Lifecycle Hooks #################################################################### + @registry.on("setup", namespace=f"{__name__}.ProcessorService") async def setup(self): # Create threading information @@ -100,6 +76,7 @@ async def setup(self): if self.setup_fn: await self.safe_exec(self.setup_fn) + @registry.on("start", namespace=f"{__name__}.ProcessorService") async def start(self): # Create a task self.running_task = asyncio.create_task(self.main()) @@ -124,20 +101,12 @@ async def main(self): elif self.operation_mode == "step": - # self.logger.debug(f"{self}: operational mode = step") - # If step or sink node, only run with inputs if self.in_bound_data: - observer = TypedObserver( - "in_step", - NewInBoundDataEvent, - on_asend=self.safe_step, - handle_event="unpack", + # self.logger.debug(f"{self}: step node: {self.state.id}") + await self.entrypoint.on( + "in_step", self.safe_step, Dict[str, DataChunk] ) - await self.eventbus.asubscribe(observer) - self.observers["in_step"] = observer - - # self.logger.debug(f"{self}: step or sink node: {self.state.id}") # If source, run as fast as possible else: @@ -145,9 +114,11 @@ async def main(self): while self.running: await self.safe_step() + @registry.on("stop", namespace=f"{__name__}.ProcessorService") async def 
stop(self): self.running = False + @registry.on("teardown", namespace=f"{__name__}.ProcessorService") async def teardown(self): # Stop things @@ -168,74 +139,67 @@ async def teardown(self): ## Debugging tools #################################################################### - async def gather(self, client: Client): - await client.async_send( - signal=NODE_MESSAGE.REPORT_GATHER, - data={ - "node_id": self.state.id, - "latest_value": self.latest_data_chunk.to_json(), - }, - ) + @registry.on("gather", namespace=f"{__name__}.ProcessorService") + async def gather(self): + data = GatherData(node_id=self.state.id, output=self.latest_data_chunk) + await self.entrypoint.emit("gather_results", data) #################################################################### ## Async Registered Methods #################################################################### + @registry.on( + "registered_method", + RegisteredMethodData, + namespace=f"{__name__}.ProcessorService", + ) async def execute_registered_method( - self, method_name: str, params: Dict, client: Optional[Client] - ) -> Dict[str, Any]: + self, reg_method_req: RegisteredMethodData + ) -> ResultsData: # First check if the request is valid - if method_name not in self.registered_methods: - results = { - "node_id": self.state.id, - "node_state": self.state.to_json(), - "success": False, - "output": None, - } + if reg_method_req.method_name not in self.registered_methods: self.logger.warning( f"{self}: Worker requested execution of registered method that doesn't " - f"exists: {method_name}" + f"exists: {reg_method_req.method_name}" ) - return {"success": False, "output": None, "node_id": self.state.id} + data = ResultsData(node_id=self.state.id, success=False, output=None) + return data # Extract the method - function: Callable[[], Coroutine] = self.registered_node_fns[method_name] - style = self.registered_methods[method_name].style - # self.logger.debug(f"{self}: executing {function} with params: {params}") + 
function: Callable[[], Coroutine] = self.registered_node_fns[ + reg_method_req.method_name + ] + style = self.registered_methods[reg_method_req.method_name].style + self.logger.debug( + f"{self}: executing {function} with params: {reg_method_req.params}" + ) # Execute method based on its style success = False if style == "concurrent": # output = await function(**params) # type: ignore[call-arg] - output, _ = await self.safe_exec(function, kwargs=params) + output, _ = await self.safe_exec(function, kwargs=reg_method_req.params) success = True elif style == "blocking": with self.step_lock: - output, _ = await self.safe_exec(function, kwargs=params) + output, _ = await self.safe_exec(function, kwargs=reg_method_req.params) success = True elif style == "reset": with self.step_lock: - output, _ = await self.safe_exec(function, kwargs=params) - await self.eventbus.asend(Event("reset")) + output, _ = await self.safe_exec(function, kwargs=reg_method_req.params) + await self.entrypoint.emit("reset") success = True else: self.logger.error(f"Invalid registered method request: style={style}") output = None - # Sending the information if client - if client: - results = { - "success": success, - "output": output, - "node_id": self.state.id, - } - await client.async_send(signal=NODE_MESSAGE.REPORT_RESULTS, data=results) - - return {"success": success, "output": output} + data = ResultsData(node_id=self.state.id, success=success, output=output) + await self.entrypoint.emit("registered_method_results", data) + return data #################################################################### ## Helper Methods @@ -247,14 +211,15 @@ async def safe_exec( # Default value output = None + tic = time.perf_counter() try: - tic = time.perf_counter() if asyncio.iscoroutinefunction(func): output = await func(*args, **kwargs) else: - await asyncio.sleep(1 / 1000) # Allow other functions to run as well - output = func(*args, **kwargs) + output = await asyncio.get_running_loop().run_in_executor( + 
self.executor, lambda: func(*args, **kwargs) + ) except Exception: traceback_info = traceback.format_exc() self.logger.error(traceback_info) @@ -300,8 +265,8 @@ async def safe_step(self, data_chunks: Dict[str, DataChunk] = {}): output_data_chunk.update("meta", meta) # Send out the output to the OutputsHandler - event_data = NewOutBoundDataEvent(output_data_chunk) - await self.eventbus.asend(Event("out_step", event_data)) + await self.entrypoint.emit("out_step", output_data_chunk) + # self.logger.debug(f"{self}: output = {output_data_chunk}") # Update the counter self.step_id += 1 diff --git a/chimerapy/engine/node/profiler_service.py b/chimerapy/engine/node/profiler_service.py index 2e98c1e5..266d4e51 100644 --- a/chimerapy/engine/node/profiler_service.py +++ b/chimerapy/engine/node/profiler_service.py @@ -6,28 +6,24 @@ from typing import Any, Dict, List, Optional import pandas as pd +from aiodistbus import registry from psutil import Process from chimerapy.engine import config from ..async_timer import AsyncTimer from ..data_protocols import NodeDiagnostics -from ..eventbus import Event, EventBus, TypedObserver from ..networking.data_chunk import DataChunk from ..service import Service from ..states import NodeState -from .events import DiagnosticsReportEvent, EnableDiagnosticsEvent, NewOutBoundDataEvent class ProfilerService(Service): - def __init__( - self, name: str, state: NodeState, eventbus: EventBus, logger: logging.Logger - ): + def __init__(self, name: str, state: NodeState, logger: logging.Logger): super().__init__(name=name) # Save parameters self.state = state - self.eventbus = eventbus self.logger = logger # State variables @@ -44,61 +40,34 @@ def __init__( self.diagnostics_report, config.get("diagnostics.interval") ) - if self.state.logdir: - self.log_file = self.state.logdir / "diagnostics.csv" - else: - raise RuntimeError(f"{self}: logdir {self.state.logdir} not set!") - - async def async_init(self): - - # Add observers to profile - self.observers: 
Dict[str, TypedObserver] = { - "setup": TypedObserver("setup", on_asend=self.setup, handle_event="drop"), - "enable_diagnostics": TypedObserver( - "enable_diagnostics", - EnableDiagnosticsEvent, - on_asend=self.enable, - handle_event="unpack", - ), - "teardown": TypedObserver( - "teardown", on_asend=self.teardown, handle_event="drop" - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - - # self.logger.debug(f"{self}: log_file={self.log_file}") - + @registry.on("enable_diagnostics", bool, namespace=f"{__name__}.ProfilerService") async def enable(self, enable: bool = True): + # self.logger.debug(f"Profiling enabled: {enable}") if enable != self._enable: if enable: # Add a timer function await self.async_timer.start() - - # Add observer - observer = TypedObserver( - "out_step", - NewOutBoundDataEvent, - on_asend=self.post_step, - handle_event="unpack", - ) - self.observers["out_step"] = observer - await self.eventbus.asubscribe(observer) + await self.entrypoint.on("out_step", self.post_step, DataChunk) else: - # self.logger.debug(f"{self}: disabled") # Stop the timer and remove the observer await self.async_timer.stop() - - observer = self.observers["enable_diagnostics"] - await self.eventbus.aunsubscribe(observer) + await self.entrypoint.off("out_step") # Update self._enable = enable + @registry.on("setup", namespace=f"{__name__}.ProfilerService") async def setup(self): + + # Construct the path + if self.state.logdir: + self.log_file = self.state.logdir / "diagnostics.csv" + else: + raise RuntimeError(f"{self}: logdir {self.state.logdir} not set!") + self.process = Process(pid=os.getpid()) async def diagnostics_report(self): @@ -131,6 +100,7 @@ async def diagnostics_report(self): # Save information then diag = NodeDiagnostics( + node_id=self.state.id, timestamp=timestamp, latency=mean_latency, payload_size=total_payload, @@ -140,13 +110,15 @@ async def diagnostics_report(self): ) # Send the information to the Worker and ultimately the 
Manager - event_data = DiagnosticsReportEvent(diag) - # self.logger.debug(f"{self}: data = {diag}") - await self.eventbus.asend(Event("diagnostics_report", event_data)) + await self.entrypoint.emit("diagnostics_report", diag) + + # self.logger.debug(f"{self}: collected diagnostics") # Write to a csv, if diagnostics enabled if config.get("diagnostics.logging-enabled"): + # self.logger.debug(f"{self}: Saving data") + # Create dictionary with units data = { "timestamp": timestamp, @@ -167,6 +139,9 @@ async def diagnostics_report(self): ) async def post_step(self, data_chunk: DataChunk): + + # self.logger.debug(f"{self}: Received data chunk {data_chunk._uuid}.") + # assert self.process if not self.process: return None @@ -198,5 +173,6 @@ async def post_step(self, data_chunk: DataChunk): def get_object_kilobytes(self, payload: Any) -> float: return len(pickle.dumps(payload)) / 1024 + @registry.on("teardown", namespace=f"{__name__}.ProfilerService") async def teardown(self): await self.async_timer.stop() diff --git a/chimerapy/engine/node/publisher_service.py b/chimerapy/engine/node/publisher_service.py index cad59edc..b0f36399 100644 --- a/chimerapy/engine/node/publisher_service.py +++ b/chimerapy/engine/node/publisher_service.py @@ -1,9 +1,10 @@ import logging from typing import Dict, Optional +from aiodistbus import registry + from chimerapy.engine import _logger -from ..eventbus import EventBus, TypedObserver from ..networking import DataChunk, Publisher from ..service import Service from ..states import NodeState @@ -17,14 +18,12 @@ def __init__( self, name: str, state: NodeState, - eventbus: EventBus, logger: Optional[logging.Logger] = None, ): super().__init__(name) # Save information self.state = state - self.eventbus = eventbus # Logging if logger: @@ -32,33 +31,21 @@ def __init__( else: self.logger = _logger.getLogger("chimerapy-engine") - async def async_init(self): - - # Add observer - self.observers: Dict[str, TypedObserver] = { - "setup": 
TypedObserver("setup", on_asend=self.setup, handle_event="drop"), - "out_step": TypedObserver( - "out_step", on_asend=self.publish, handle_event="unpack" - ), - "teardown": TypedObserver( - "teardown", on_asend=self.teardown, handle_event="drop" - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - - def setup(self): + @registry.on("setup", namespace=f"{__name__}.PublisherService") + async def setup(self): # Creating publisher self.publisher = Publisher() self.publisher.start() self.state.port = self.publisher.port + @registry.on("out_step", DataChunk, namespace=f"{__name__}.PublisherService") async def publish(self, data_chunk: DataChunk): # self.logger.debug(f"{self}: publishing {data_chunk}") await self.publisher.publish(data_chunk.to_bytes()) - def teardown(self): + @registry.on("teardown", namespace=f"{__name__}.PublisherService") + async def teardown(self): # Shutting down publisher if self.publisher: diff --git a/chimerapy/engine/node/record_service.py b/chimerapy/engine/node/record_service.py index 4769fbf3..bda02cd6 100644 --- a/chimerapy/engine/node/record_service.py +++ b/chimerapy/engine/node/record_service.py @@ -4,9 +4,10 @@ import threading from typing import Dict, Optional +from aiodistbus import registry + from chimerapy.engine import _logger -from ..eventbus import EventBus, TypedObserver from ..records import ( AudioRecord, ImageRecord, @@ -27,14 +28,12 @@ def __init__( self, name: str, state: NodeState, - eventbus: EventBus, logger: Optional[logging.Logger] = None, ): super().__init__(name) # Saving parameters self.state = state - self.eventbus = eventbus # State variables self.save_queue: queue.Queue = queue.Queue() @@ -63,39 +62,35 @@ def __init__( else: self.logger = _logger.getLogger("chimerapy-engine") - async def async_init(self): - - # Put observers - self.observers: Dict[str, TypedObserver] = { - "setup": TypedObserver("setup", on_asend=self.setup, handle_event="drop"), - "record": TypedObserver( - 
"record", on_asend=self.record, handle_event="drop" - ), - "collect": TypedObserver( - "collect", on_asend=self.collect, handle_event="drop" - ), - "teardown": TypedObserver( - "teardown", on_asend=self.teardown, handle_event="drop" - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - #################################################################### ## Lifecycle Hooks #################################################################### - async def setup(self): + @registry.on("setup", namespace=f"{__name__}.RecordService") + def setup(self): # self.logger.debug(f"{self}: executing main") self._record_thread = threading.Thread(target=self.run) self._record_thread.start() - async def record(self): + @registry.on("record", namespace=f"{__name__}.RecordService") + def record(self): # self.logger.debug(f"{self}: Starting recording") ... - async def teardown(self): + @registry.on("collect", namespace=f"{__name__}.RecordService") + def collect(self): + # self.logger.debug(f"{self}: collecting recording") + + # Signal to stop and save + self.is_running.clear() + if self._record_thread: + self._record_thread.join() + + # self.logger.debug(f"{self}: Finish saving records") + + @registry.on("teardown", namespace=f"{__name__}.RecordService") + def teardown(self): # First, indicate the end self.is_running.clear() @@ -108,11 +103,14 @@ async def teardown(self): ## Helper Methods & Attributes #################################################################### - @property - def enabled(self) -> bool: - return self.state.fsm == "RECORDING" - + @registry.on("record_entry", Dict, namespace=f"{__name__}.RecordService") def submit(self, entry: Dict): + + # self.logger.debug(f"{self}: Received data: {entry}") + + if self.state.fsm != "RECORDING": + return None + self.save_queue.put(entry) def run(self): @@ -148,13 +146,3 @@ def run(self): entry.close() # self.logger.debug(f"{self}: Closed all entries") - - def collect(self): - # 
self.logger.debug(f"{self}: collecting recording") - - # Signal to stop and save - self.is_running.clear() - if self._record_thread: - self._record_thread.join() - - # self.logger.debug(f"{self}: Finish saving records") diff --git a/chimerapy/engine/node/registered_method.py b/chimerapy/engine/node/registered_method.py index 648e0a65..fdf59124 100644 --- a/chimerapy/engine/node/registered_method.py +++ b/chimerapy/engine/node/registered_method.py @@ -1,15 +1,6 @@ -from dataclasses import dataclass, field from typing import Callable, Dict, Literal -from dataclasses_json import dataclass_json - - -@dataclass_json -@dataclass -class RegisteredMethod: - name: str - style: str = "concurrent" # Literal['concurrent', 'blocking', 'reset'] - params: Dict[str, str] = field(default_factory=dict) +from ..data_protocols import RegisteredMethod # Reference: diff --git a/chimerapy/engine/node/worker_comms_service.py b/chimerapy/engine/node/worker_comms_service.py index 1f055fda..2b902759 100644 --- a/chimerapy/engine/node/worker_comms_service.py +++ b/chimerapy/engine/node/worker_comms_service.py @@ -3,19 +3,22 @@ import tempfile from typing import Dict, Optional -from ..data_protocols import NodeDiagnostics, NodePubTable -from ..eventbus import Event, EventBus, TypedObserver -from ..networking import Client -from ..networking.enums import GENERAL_MESSAGE, NODE_MESSAGE, WORKER_MESSAGE +from aiodistbus import registry + +from chimerapy.engine import config + +from ..data_protocols import ( + GatherData, + NodeDiagnostics, + NodePubTable, + PreSetupData, + RegisteredMethodData, + ResultsData, +) +from ..networking import Client, DataChunk +from ..networking.enums import NODE_MESSAGE, WORKER_MESSAGE from ..service import Service from ..states import NodeState -from .events import ( - DiagnosticsReportEvent, - EnableDiagnosticsEvent, - GatherEvent, - ProcessNodePubTableEvent, - RegisteredMethodEvent, -) from .node_config import NodeConfig @@ -32,7 +35,6 @@ def __init__( 
worker_logging_port: int = 5555, state: Optional[NodeState] = None, logger: Optional[logging.Logger] = None, - eventbus: Optional[EventBus] = None, ): super().__init__(name=name) @@ -47,7 +49,6 @@ def __init__( # Optional self.state = state self.logger = logger - self.eventbus = eventbus if worker_logdir: self.worker_logdir = worker_logdir @@ -58,47 +59,36 @@ def __init__( self.running: bool = False self.client: Optional[Client] = None - async def async_init(self): - - assert self.state and self.eventbus and self.logger - - observers: Dict[str, TypedObserver] = { - "setup": TypedObserver("setup", on_asend=self.setup, handle_event="drop"), - "NodeState.changed": TypedObserver( - "NodeState.changed", on_asend=self.send_state, handle_event="drop" - ), - "diagnostics_report": TypedObserver( - "diagnostics_report", - DiagnosticsReportEvent, - on_asend=self.send_diagnostics, - handle_event="unpack", - ), - "teardown": TypedObserver( - "teardown", on_asend=self.teardown, handle_event="drop" - ), - } - for ob in observers.values(): - await self.eventbus.asubscribe(ob) - #################################################################### ## Helper Functions #################################################################### - def in_node_config( - self, state: NodeState, logger: logging.Logger, eventbus: EventBus - ): + @registry.on("pre_setup", PreSetupData, namespace=f"{__name__}.WorkerCommsService") + def in_node_config(self, presetup_data: PreSetupData): # Save parameters - self.state = state - self.logger = logger - self.eventbus = eventbus + self.state = presetup_data.state + self.logger = presetup_data.logger + + if self.worker_config: + config.update_defaults(self.worker_config) + + def check(self) -> bool: + if not self.logger: + raise RuntimeError(f"{self}: logger not set!") + if not self.state: + self.logger.error(f"{self}: NodeState not set!") + return False + return True #################################################################### ## Lifecycle Hooks 
#################################################################### + @registry.on("setup", namespace=f"{__name__}.WorkerCommsService") async def setup(self): - assert self.state and self.eventbus and self.logger + if not self.check(): + return # self.logger.debug( # f"{self}: Prepping the networking component of the Node, connecting to " @@ -111,7 +101,6 @@ async def setup(self): port=self.port, id=self.state.id, ws_handlers={ - GENERAL_MESSAGE.SHUTDOWN: self.shutdown, WORKER_MESSAGE.BROADCAST_NODE_SERVER: self.process_node_pub_table, WORKER_MESSAGE.REQUEST_STEP: self.async_step, WORKER_MESSAGE.REQUEST_COLLECT: self.provide_collect, @@ -123,13 +112,13 @@ async def setup(self): WORKER_MESSAGE.DIAGNOSTICS: self.enable_diagnostics, }, parent_logger=self.logger, - thread=self.eventbus.thread, ) await self.client.async_connect() # Send publisher port and host information await self.send_state() + @registry.on("teardown", namespace=f"{__name__}.WorkerCommsService") async def teardown(self): # Shutdown the client @@ -141,10 +130,14 @@ async def teardown(self): ## Message Requests #################################################################### - async def send_state(self): - assert self.state and self.eventbus and self.logger + @registry.on( + "NodeState.changed", NodeState, namespace=f"{__name__}.WorkerCommsService" + ) + async def send_state(self, state: Optional[NodeState] = None): + if not self.check(): + return - # Save container informaiton + # self.logger.debug(f"{self}: Sending NodeState: {self.state.to_dict()}") jsonable_state = self.state.to_dict() jsonable_state["logdir"] = str(jsonable_state["logdir"]) if self.client: @@ -153,69 +146,110 @@ async def send_state(self): ) async def provide_gather(self, msg: Dict): - assert self.state and self.eventbus and self.logger + if not self.check(): + return + + # self.logger.debug(f"{self}: Sending gather") + + if self.client: + # self.logger.debug(f"{self}: Sending gather with client") + await 
self.entrypoint.emit("gather") + + @registry.on( + "gather_results", GatherData, namespace=f"{__name__}.WorkerCommsService" + ) + async def send_gather(self, gather_data: GatherData): + if not self.check(): + return + + # self.logger.debug(f"{self}: Sending gather results: {gather_data}") + + # If gather data is DataChunk, serialize it + if isinstance(gather_data.output, DataChunk): + gather_data.output = gather_data.output.to_json() if self.client: - event_data = GatherEvent(self.client) - await self.eventbus.asend(Event("gather", event_data)) + await self.client.async_send( + signal=NODE_MESSAGE.REPORT_GATHER, data=gather_data.to_dict() + ) + @registry.on( + "report_diagnostics", + NodeDiagnostics, + namespace=f"{__name__}.WorkerCommsService", + ) async def send_diagnostics(self, diagnostics: NodeDiagnostics): - assert self.state and self.eventbus and self.logger + if not self.check(): + return # self.logger.debug(f"{self}: Sending Diagnostics") if self.client: - data = {"node_id": self.state.id, "diagnostics": diagnostics.to_dict()} - await self.client.async_send(signal=NODE_MESSAGE.DIAGNOSTICS, data=data) + await self.client.async_send( + signal=NODE_MESSAGE.DIAGNOSTICS, data=diagnostics.to_dict() + ) #################################################################### ## Message Responds #################################################################### async def process_node_pub_table(self, msg: Dict): - assert self.state and self.eventbus and self.logger - - node_pub_table = NodePubTable.from_dict(msg["data"]) + if not self.check(): + return # Pass the information to the Poller Service - event_data = ProcessNodePubTableEvent(node_pub_table) - await self.eventbus.asend(Event("setup_connections", event_data)) + node_pub_table = NodePubTable.from_dict(msg["data"]) + await self.entrypoint.emit("setup_connections", node_pub_table) async def start_node(self, msg: Dict = {}): - assert self.state and self.eventbus and self.logger - await 
self.eventbus.asend(Event("start")) + if not self.check(): + return + await self.entrypoint.emit("start") async def record_node(self, msg: Dict): - assert self.state and self.eventbus and self.logger - await self.eventbus.asend(Event("record")) + if not self.check(): + return + await self.entrypoint.emit("record") async def stop_node(self, msg: Dict): - assert self.state and self.eventbus and self.logger - await self.eventbus.asend(Event("stop")) + if not self.check(): + return + await self.entrypoint.emit("stop") async def provide_collect(self, msg: Dict): - assert self.state and self.eventbus and self.logger - await self.eventbus.asend(Event("collect")) + if not self.check(): + return + await self.entrypoint.emit("collect") async def execute_registered_method(self, msg: Dict): - assert self.state and self.eventbus and self.logger + if not self.check(): + return + # Check first that the method exists - method_name, params = (msg["data"]["method_name"], msg["data"]["params"]) + event_data = RegisteredMethodData.from_dict(msg["data"]) # Send the event + await self.entrypoint.emit("registered_method", event_data) + + @registry.on( + "registered_method_results", + ResultsData, + namespace=f"{__name__}.WorkerCommsService", + ) + async def report_registered_method_results(self, results: ResultsData): + if not self.check(): + return + + # Send the results if self.client: - event_data = RegisteredMethodEvent( - method_name=method_name, params=params, client=self.client + await self.client.async_send( + signal=NODE_MESSAGE.REPORT_RESULTS, data=results.to_dict() ) - await self.eventbus.asend(Event("registered_method", event_data)) async def async_step(self, msg: Dict): - assert self.state and self.eventbus and self.logger - await self.eventbus.asend(Event("manual_step")) + if not self.check(): + return + await self.entrypoint.emit("manual_step") async def enable_diagnostics(self, msg: Dict): - assert self.state and self.eventbus and self.logger - enable = 
msg["data"]["enable"] - - event_data = EnableDiagnosticsEvent(enable) - await self.eventbus.asend(Event("enable_diagnostics", event_data)) + await self.entrypoint.emit("enable_diagnostics", msg["data"]["enable"]) diff --git a/chimerapy/engine/service.py b/chimerapy/engine/service.py index cb44af38..28fca954 100644 --- a/chimerapy/engine/service.py +++ b/chimerapy/engine/service.py @@ -1,49 +1,31 @@ -from collections import UserDict -from typing import Any, Dict, List, Optional +from typing import Optional + +from aiodistbus import EntryPoint, EventBus, registry class Service: def __init__(self, name: str): self.name = name + self.entrypoint = EntryPoint() def __str__(self): return f"<{self.__class__.__name__}, name={self.name}>" - def shutdown(self): - ... - - -class ServiceGroup(UserDict): - - data: Dict[str, Service] - - def apply(self, method_name: str, order: Optional[List[str]] = None): + async def attach(self, bus: EventBus): + """Attach the service to the bus. - if order: - for s_name in order: - if s_name in self.data: - s = self.data[s_name] - func = getattr(s, method_name) - func() - else: - for s in self.data.values(): - func = getattr(s, method_name) - func() + This is where the service should register its entrypoint and connect to the bus. - async def async_apply( - self, method_name: str, order: Optional[List[str]] = None - ) -> List[Any]: + Args: + bus (EventBus): The bus to attach to. - outputs: List[Any] = [] - if order: - for s_name in order: - if s_name in self.data: - s = self.data[s_name] - func = getattr(s, method_name) - outputs.append(await func()) - else: - for s in self.data.values(): - func = getattr(s, method_name) - outputs.append(await func()) + Raises: + ValueError: If the registry is empty for service's Namespace. 
- return outputs + """ + await self.entrypoint.connect(bus) + await self.entrypoint.use( + registry, + b_args=[self], + namespace=f"{self.__class__.__module__}.{self.__class__.__name__}", + ) diff --git a/chimerapy/engine/utils.py b/chimerapy/engine/utils.py index 92c42d74..2d37d2d2 100644 --- a/chimerapy/engine/utils.py +++ b/chimerapy/engine/utils.py @@ -57,6 +57,13 @@ async def wrapper(): return wrapper(), future +def run_coroutine_in_thread(coro: Coroutine) -> Future: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + future = loop.run_until_complete(coro) + return future + + async def async_waiting_for( condition: Callable[[], bool], check_period: Union[int, float] = 0.1, diff --git a/chimerapy/engine/worker/http_client_service.py b/chimerapy/engine/worker/http_client_service.py index 26e2a683..78865f55 100644 --- a/chimerapy/engine/worker/http_client_service.py +++ b/chimerapy/engine/worker/http_client_service.py @@ -7,20 +7,20 @@ import socket import traceback import uuid -from typing import Dict, Literal, Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union import aiohttp +from aiodistbus import registry from zeroconf import ServiceBrowser, Zeroconf from chimerapy.engine import _logger, config -from ..eventbus import EventBus, TypedObserver +from ..data_protocols import ConnectData from ..logger.zmq_handlers import NodeIDZMQPullListener from ..networking import Client from ..service import Service from ..states import WorkerState from ..utils import get_ip_address -from .events import SendArchiveEvent from .zeroconf_listener import ZeroconfListener @@ -29,7 +29,6 @@ def __init__( self, name: str, state: WorkerState, - eventbus: EventBus, logger: logging.Logger, logreceiver: NodeIDZMQPullListener, ): @@ -37,7 +36,6 @@ def __init__( # Input parameters self.state = state - self.eventbus = eventbus self.logger = logger self.logreceiver = logreceiver @@ -51,31 +49,18 @@ def __init__( # Services self.http_client = 
aiohttp.ClientSession() - async def async_init(self): - - # Specify observers - self.observers: Dict[str, TypedObserver] = { - "shutdown": TypedObserver( - "shutdown", on_asend=self.shutdown, handle_event="drop" - ), - "WorkerState.changed": TypedObserver( - "WorkerState.changed", - on_asend=self._async_node_status_update, - handle_event="drop", - ), - "send_archive": TypedObserver( - "send_archive", - SendArchiveEvent, - on_asend=self._send_archive, - handle_event="unpack", - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - def get_address(self) -> Tuple[str, int]: return self.manager_host, self.manager_port + @registry.on("connect", ConnectData, namespace=f"{__name__}.HttpClientService") + async def async_connect_handler(self, connect_data: ConnectData): + await self.async_connect( + host=connect_data.host, + port=connect_data.port, + method=connect_data.method, + ) + + @registry.on("shutdown", namespace=f"{__name__}.HttpClientService") async def shutdown(self) -> bool: success = True @@ -255,6 +240,7 @@ async def _async_connect_via_zeroconf( return success + @registry.on("send_archive", pathlib.Path, f"{__name__}.HttpClientService") async def _send_archive(self, path: pathlib.Path) -> bool: # Flag @@ -338,7 +324,12 @@ async def _send_archive_remotely(self, host: str, port: int) -> bool: return False - async def _async_node_status_update(self) -> bool: + @registry.on("WorkerState.changed", WorkerState, f"{__name__}.HttpClientService") + async def _async_node_status_update( + self, state: Optional[WorkerState] = None + ) -> bool: + + # self.logger.debug(f"{self}: sending node status update") if not self.connected_to_manager: return False diff --git a/chimerapy/engine/worker/http_server_service.py b/chimerapy/engine/worker/http_server_service.py index 8eaba1d3..1b4259c4 100644 --- a/chimerapy/engine/worker/http_server_service.py +++ b/chimerapy/engine/worker/http_server_service.py @@ -1,35 +1,28 @@ import asyncio -import enum 
+import json import logging import pathlib import pickle from typing import Dict, List +from aiodistbus import EventBus, registry from aiohttp import web from ..data_protocols import ( + GatherData, NodeDiagnostics, NodePubEntry, NodePubTable, + RegisteredMethodData, + ResultsData, + ServerMessage, ) -from ..eventbus import Event, EventBus, TypedObserver -from ..networking import Server +from ..networking import DataChunk, Server from ..networking.enums import NODE_MESSAGE +from ..node import NodeConfig from ..service import Service from ..states import NodeState, WorkerState from ..utils import update_dataclass -from .events import ( - BroadcastEvent, - CreateNodeEvent, - DestroyNodeEvent, - EnableDiagnosticsEvent, - ProcessNodePubTableEvent, - RegisteredMethodEvent, - SendArchiveEvent, - SendMessageEvent, - UpdateGatherEvent, - UpdateResultsEvent, -) class HttpServerService(Service): @@ -37,14 +30,13 @@ def __init__( self, name: str, state: WorkerState, - eventbus: EventBus, logger: logging.Logger, ): + super().__init__(name=name) # Save input parameters self.name = name self.state = state - self.eventbus = eventbus self.logger = logger # Containers @@ -91,30 +83,7 @@ def port(self) -> int: def url(self) -> str: return f"http://{self._ip}:{self._port}" - async def async_init(self): - - # Specify observers - self.observers: Dict[str, TypedObserver] = { - "start": TypedObserver("start", on_asend=self.start, handle_event="drop"), - "shutdown": TypedObserver( - "shutdown", on_asend=self.shutdown, handle_event="drop" - ), - "broadcast": TypedObserver( - "broadcast", - BroadcastEvent, - on_asend=self._async_broadcast, - handle_event="unpack", - ), - "send": TypedObserver( - "send", - SendMessageEvent, - on_asend=self._async_send, - handle_event="unpack", - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - + @registry.on("start", namespace=f"{__name__}.HttpServerService") async def start(self): # Runn the Server @@ -126,8 +95,9 @@ async 
def start(self): self.state.port = self.port # After updatign the information, then run it! - await self.eventbus.asend(Event("after_server_startup")) + await self.entrypoint.emit("after_server_startup") + @registry.on("shutdown", namespace=f"{__name__}.HttpServerService") async def shutdown(self) -> bool: return await self.server.async_shutdown() @@ -135,13 +105,18 @@ async def shutdown(self) -> bool: ## Helper Functions #################################################################### - async def _async_send(self, client_id: str, signal: enum.Enum, data: Dict) -> bool: + @registry.on("send", ServerMessage, namespace=f"{__name__}.HttpServerService") + async def _async_send(self, msg: ServerMessage) -> bool: + if not isinstance(msg.client_id, str): + self.logger.error(f"{self}: Missing client_id") + return False return await self.server.async_send( - client_id=client_id, signal=signal, data=data + client_id=msg.client_id, signal=msg.signal, data=msg.data ) - async def _async_broadcast(self, signal: enum.Enum, data: Dict) -> bool: - return await self.server.async_broadcast(signal=signal, data=data) + @registry.on("broadcast", ServerMessage, namespace=f"{__name__}.HttpServerService") + async def _async_broadcast(self, msg: ServerMessage) -> bool: + return await self.server.async_broadcast(signal=msg.signal, data=msg.data) def _create_node_pub_table(self) -> NodePubTable: @@ -154,12 +129,12 @@ def _create_node_pub_table(self) -> NodePubTable: return node_pub_table async def _collect_and_send(self, path: pathlib.Path): + # Collect data from the Nodes - await self.eventbus.asend(Event("collect")) + await self.entrypoint.emit("collect") # After collecting, request to send the archive - event_data = SendArchiveEvent(path) - await self.eventbus.asend(Event("send_archive", event_data)) + await self.entrypoint.emit("send_archive", path) #################################################################### ## HTTP Routes @@ -216,70 +191,73 @@ async def 
_collect_and_send(self, path: pathlib.Path): # return web.HTTPOk() async def _async_create_node_route(self, request: web.Request) -> web.Response: + msg_bytes = await request.read() # Create node - node_config = pickle.loads(msg_bytes) - await self.eventbus.asend(Event("create_node", CreateNodeEvent(node_config))) + node_config: NodeConfig = pickle.loads(msg_bytes) + await self.entrypoint.emit("create_node", node_config) return web.HTTPOk() async def _async_destroy_node_route(self, request: web.Request) -> web.Response: + msg = await request.json() # Destroy Node - node_id = msg["id"] - await self.eventbus.asend(Event("destroy_node", DestroyNodeEvent(node_id))) + node_id: str = msg["id"] + await self.entrypoint.emit("destroy_node", node_id) return web.HTTPOk() async def _async_get_node_pub_table(self, request: web.Request) -> web.Response: - node_pub_table = self._create_node_pub_table() return web.json_response(node_pub_table.to_json()) async def _async_process_node_pub_table(self, request: web.Request) -> web.Response: + msg = await request.json() node_pub_table: NodePubTable = NodePubTable.from_dict(msg) # Broadcasting the node server data - await self.eventbus.asend( - Event("process_node_pub_table", ProcessNodePubTableEvent(node_pub_table)) - ) + await self.entrypoint.emit("process_node_pub_table", node_pub_table) return web.HTTPOk() async def _async_step_route(self, request: web.Request) -> web.Response: - await self.eventbus.asend(Event("step_nodes")) + + await self.entrypoint.emit("step_nodes") return web.HTTPOk() async def _async_start_nodes_route(self, request: web.Request) -> web.Response: - await self.eventbus.asend(Event("start_nodes")) + + await self.entrypoint.emit("start_nodes") return web.HTTPOk() async def _async_record_route(self, request: web.Request) -> web.Response: - await self.eventbus.asend(Event("record_nodes")) + + await self.entrypoint.emit("record_nodes") return web.HTTPOk() async def _async_request_method_route(self, request: 
web.Request) -> web.Response: + msg = await request.json() # Get event information - event_data = RegisteredMethodEvent( - node_id=msg["node_id"], method_name=msg["method_name"], params=msg["params"] - ) + reg_method_data = RegisteredMethodData.from_dict(msg) # Send it! - await self.eventbus.asend(Event("registered_method", event_data)) - + await self.entrypoint.emit("registered_method", reg_method_data) return web.HTTPOk() async def _async_stop_nodes_route(self, request: web.Request) -> web.Response: - await self.eventbus.asend(Event("stop_nodes")) + + await self.entrypoint.emit("stop_nodes") return web.HTTPOk() async def _async_report_node_gather(self, request: web.Request) -> web.Response: - await self.eventbus.asend(Event("gather_nodes")) + + await self.entrypoint.emit("gather_nodes") self.logger.warning(f"{self}: gather doesn't work ATM.") gather_data = {"id": self.state.id, "node_data": {}} @@ -291,17 +269,18 @@ async def _async_collect(self, request: web.Request) -> web.Response: return web.HTTPOk() async def _async_diagnostics_route(self, request: web.Request) -> web.Response: + data = await request.json() # Determine if enable/disable - event_data = EnableDiagnosticsEvent(data["enable"]) - await self.eventbus.asend(Event("diagnostics", event_data)) + enable: bool = data["enable"] + await self.entrypoint.emit("diagnostics", enable) return web.HTTPOk() async def _async_shutdown_route(self, request: web.Request) -> web.Response: - # Execute shutdown after returning HTTPOk (prevent Manager stuck waiting) - self.tasks.append(asyncio.create_task(self.eventbus.asend(Event("shutdown")))) + # Execute shutdown after returning HTTPOk (prevent Manager stuck waiting) + self.tasks.append(asyncio.create_task(self.entrypoint.emit("shutdown"))) return web.HTTPOk() #################################################################### @@ -310,45 +289,30 @@ async def _async_shutdown_route(self, request: web.Request) -> web.Response: async def _async_node_status_update(self, 
msg: Dict, ws: web.WebSocketResponse): - # self.logger.debug(f"{self}: note_status_update: :{msg}") + # self.logger.debug(f"{self}: node_status_update: :{msg}") node_state = NodeState.from_dict(msg["data"]) - node_id = node_state.id # Update our records by grabbing all data from the msg - if node_id in self.state.nodes and node_state: + if node_state.id in self.state.nodes and node_state: # Update the node state - update_dataclass(self.state.nodes[node_id], node_state) - await self.eventbus.asend(Event("WorkerState.changed", self.state)) + update_dataclass(self.state.nodes[node_state.id], node_state) + await self.entrypoint.emit("WorkerState.changed", self.state) async def _async_node_report_gather(self, msg: Dict, ws: web.WebSocketResponse): - - # Saving gathering value - node_id = msg["data"]["node_id"] - - await self.eventbus.asend( - Event( - "update_gather", - UpdateGatherEvent(node_id=node_id, gather=msg["data"]["latest_value"]), - ) - ) + gather_data = GatherData.from_dict(msg["data"]) + gather_data.output = DataChunk.from_json(gather_data.output) + await self.entrypoint.emit("update_gather", gather_data) async def _async_node_report_results(self, msg: Dict, ws: web.WebSocketResponse): - - node_id = msg["data"]["node_id"] - await self.eventbus.asend( - Event( - "update_results", - UpdateResultsEvent(node_id=node_id, results=msg["data"]["output"]), - ) - ) + results = ResultsData.from_dict(msg["data"]) + await self.entrypoint.emit("update_results", results) async def _async_node_diagnostics(self, msg: Dict, ws: web.WebSocketResponse): # self.logger.debug(f"{self}: received diagnostics: {msg}") # Create the entry and update the table - node_id: str = msg["data"]["node_id"] - diag = NodeDiagnostics.from_dict(msg["data"]["diagnostics"]) - if node_id in self.state.nodes: - self.state.nodes[node_id].diagnostics = diag + diag = NodeDiagnostics.from_dict(msg["data"]) + if diag.node_id in self.state.nodes: + self.state.nodes[diag.node_id].diagnostics = diag diff 
--git a/chimerapy/engine/worker/node_handler_service/context_session.py b/chimerapy/engine/worker/node_handler_service/context_session.py index e74d0fb7..a4806b73 100644 --- a/chimerapy/engine/worker/node_handler_service/context_session.py +++ b/chimerapy/engine/worker/node_handler_service/context_session.py @@ -54,7 +54,7 @@ def shutdown(self): class MPSession(ContextSession): def __init__(self): self.loop = asyncio.get_running_loop() - self.pool = mp.Pool(processes=1) + self.pool = mp.Pool() self.executor = MultiprocessExecutor(self.pool) self.futures = [] diff --git a/chimerapy/engine/worker/node_handler_service/node_handler_service.py b/chimerapy/engine/worker/node_handler_service/node_handler_service.py index ff82be78..8bbaf3c2 100644 --- a/chimerapy/engine/worker/node_handler_service/node_handler_service.py +++ b/chimerapy/engine/worker/node_handler_service/node_handler_service.py @@ -6,11 +6,17 @@ # Third-party Imports import dill import multiprocess as mp +from aiodistbus import registry from chimerapy.engine import config -from ...data_protocols import NodePubTable -from ...eventbus import Event, EventBus, TypedObserver +from ...data_protocols import ( + GatherData, + NodePubTable, + RegisteredMethodData, + ResultsData, + ServerMessage, +) from ...logger.zmq_handlers import NodeIDZMQPullListener from ...networking import DataChunk from ...networking.enums import WORKER_MESSAGE @@ -19,17 +25,6 @@ from ...service import Service from ...states import NodeState, WorkerState from ...utils import async_waiting_for -from ..events import ( - BroadcastEvent, - CreateNodeEvent, - DestroyNodeEvent, - EnableDiagnosticsEvent, - ProcessNodePubTableEvent, - RegisteredMethodEvent, - SendMessageEvent, - UpdateGatherEvent, - UpdateResultsEvent, -) from .context_session import ContextSession, MPSession, ThreadSession from .node_controller import MPNodeController, NodeController, ThreadNodeController @@ -39,7 +34,6 @@ def __init__( self, name: str, state: WorkerState, - 
eventbus: EventBus, logger: logging.Logger, logreceiver: NodeIDZMQPullListener, ): @@ -47,7 +41,6 @@ def __init__( # Input parameters self.state = state - self.eventbus = eventbus self.logger = logger self.logreceiver = logreceiver @@ -61,78 +54,7 @@ def __init__( "threading": ThreadNodeController, } - async def async_init(self): - - # Specify observers - self.observers: Dict[str, TypedObserver] = { - "start": TypedObserver("start", on_asend=self.start, handle_event="drop"), - "shutdown": TypedObserver( - "shutdown", on_asend=self.shutdown, handle_event="drop" - ), - "create_node": TypedObserver( - "create_node", - CreateNodeEvent, - on_asend=self.async_create_node, - handle_event="unpack", - ), - "destroy_node": TypedObserver( - "destroy_node", - DestroyNodeEvent, - on_asend=self.async_destroy_node, - handle_event="unpack", - ), - "process_node_pub_table": TypedObserver( - "process_node_pub_table", - ProcessNodePubTableEvent, - on_asend=self.async_process_node_pub_table, - handle_event="unpack", - ), - "step_nodes": TypedObserver( - "step_nodes", on_asend=self.async_step, handle_event="drop" - ), - "start_nodes": TypedObserver( - "start_nodes", on_asend=self.async_start_nodes, handle_event="drop" - ), - "stop_nodes": TypedObserver( - "stop_nodes", on_asend=self.async_stop_nodes, handle_event="drop" - ), - "record_nodes": TypedObserver( - "record_nodes", on_asend=self.async_record_nodes, handle_event="drop" - ), - "registered_method": TypedObserver( - "registered_method", - RegisteredMethodEvent, - on_asend=self.async_request_registered_method, - handle_event="unpack", - ), - "collect": TypedObserver( - "collect", on_asend=self.async_collect, handle_event="drop" - ), - "gather_nodes": TypedObserver( - "gather_nodes", on_asend=self.async_gather, handle_event="drop" - ), - "diagnostics": TypedObserver( - "diagnostics", - EnableDiagnosticsEvent, - on_asend=self.async_diagnostics, - handle_event="unpack", - ), - "update_gather": TypedObserver( - "update_gather", - 
UpdateGatherEvent, - on_asend=self.update_gather, - handle_event="unpack", - ), - "update_results": TypedObserver( - "update_results", - UpdateResultsEvent, - on_asend=self.update_results, - handle_event="unpack", - ), - } - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - + @registry.on("start", namespace=f"{__name__}.NodeHandlerService") async def start(self) -> bool: # Containers self.mp_session = MPSession() @@ -143,6 +65,7 @@ async def start(self) -> bool: } return True + @registry.on("shutdown", namespace=f"{__name__}.NodeHandlerService") async def shutdown(self) -> bool: tasks = [ @@ -164,18 +87,27 @@ async def shutdown(self) -> bool: ## Helper Functions ################################################################################### - def update_gather(self, node_id: str, gather: Any): - self.node_controllers[node_id].gather = gather + @registry.on( + "update_gather", GatherData, namespace=f"{__name__}.NodeHandlerService" + ) + def update_gather(self, gather_data: GatherData): + node_id = gather_data.node_id + self.node_controllers[node_id].gather = gather_data.output self.node_controllers[node_id].response = True - def update_results(self, node_id: str, results: Any): - self.node_controllers[node_id].registered_method_results = results + @registry.on( + "update_results", ResultsData, namespace=f"{__name__}.NodeHandlerService" + ) + def update_results(self, results_data: ResultsData): + node_id = results_data.node_id + self.node_controllers[node_id].registered_method_results = results_data.output self.node_controllers[node_id].response = True ################################################################################### ## Node Handling ################################################################################### + @registry.on("create_node", NodeConfig, namespace=f"{__name__}.NodeHandlerService") async def async_create_node(self, node_config: Union[NodeConfig, Dict]) -> bool: # Ensure to convert the node_config into a 
NodeConfig object @@ -219,7 +151,6 @@ async def async_create_node(self, node_config: Union[NodeConfig, Dict]) -> bool: logging_level=self.logger.level, worker_logging_port=self.logreceiver.port, ) - # worker_comms.inject(node_object) node_object.add_worker_comms(worker_comms) # Create controller @@ -266,6 +197,7 @@ async def async_create_node(self, node_config: Union[NodeConfig, Dict]) -> bool: return success + @registry.on("destroy_node", str, namespace=f"{__name__}.NodeHandlerService") async def async_destroy_node(self, node_id: str) -> bool: # self.logger.debug(f"{self}: received request for Node {node_id} destruction") @@ -283,16 +215,19 @@ async def async_destroy_node(self, node_id: str) -> bool: return success + @registry.on( + "process_node_pub_table", + NodePubTable, + namespace=f"{__name__}.NodeHandlerService", + ) async def async_process_node_pub_table(self, node_pub_table: NodePubTable) -> bool: - await self.eventbus.asend( - Event( - "broadcast", - BroadcastEvent( - signal=WORKER_MESSAGE.BROADCAST_NODE_SERVER, - data=node_pub_table.to_dict(), - ), - ) + await self.entrypoint.emit( + "broadcast", + ServerMessage( + signal=WORKER_MESSAGE.BROADCAST_NODE_SERVER, + data=node_pub_table.to_dict(), + ), ) # Now wait until all nodes have responded as CONNECTED @@ -307,7 +242,7 @@ async def async_process_node_pub_table(self, node_pub_table: NodePubTable) -> bo success.append(True) break else: - self.logger.debug(f"{self}: Node {node_id} has connected: FAILED") + self.logger.error(f"{self}: Node {node_id} has connected: FAILED") success.append(False) if not all(success): @@ -315,31 +250,51 @@ async def async_process_node_pub_table(self, node_pub_table: NodePubTable) -> bo return all(success) + @registry.on("start_nodes", namespace=f"{__name__}.NodeHandlerService") async def async_start_nodes(self) -> bool: + # Send message to nodes to start - await self.eventbus.asend( - Event("broadcast", BroadcastEvent(signal=WORKER_MESSAGE.START_NODES)) + await 
self.entrypoint.emit( + "broadcast", + ServerMessage( + signal=WORKER_MESSAGE.START_NODES, + ), ) return True + @registry.on("record_nodes", namespace=f"{__name__}.NodeHandlerService") async def async_record_nodes(self) -> bool: + # Send message to nodes to start - await self.eventbus.asend( - Event("broadcast", BroadcastEvent(signal=WORKER_MESSAGE.RECORD_NODES)) + await self.entrypoint.emit( + "broadcast", + ServerMessage( + signal=WORKER_MESSAGE.RECORD_NODES, + ), ) return True + @registry.on("step_nodes", namespace=f"{__name__}.NodeHandlerService") async def async_step(self) -> bool: + # Worker tell all nodes to take a step - await self.eventbus.asend( - Event("broadcast", BroadcastEvent(signal=WORKER_MESSAGE.REQUEST_STEP)) + await self.entrypoint.emit( + "broadcast", + ServerMessage( + signal=WORKER_MESSAGE.REQUEST_STEP, + ), ) return True + @registry.on("stop_nodes", namespace=f"{__name__}.NodeHandlerService") async def async_stop_nodes(self) -> bool: + # Send message to nodes to start - await self.eventbus.asend( - Event("broadcast", BroadcastEvent(signal=WORKER_MESSAGE.STOP_NODES)) + await self.entrypoint.emit( + "broadcast", + ServerMessage( + signal=WORKER_MESSAGE.STOP_NODES, + ), ) await async_waiting_for( lambda: all( @@ -351,44 +306,46 @@ async def async_stop_nodes(self) -> bool: ) return True + @registry.on("registered_method", namespace=f"{__name__}.NodeHandlerService") async def async_request_registered_method( - self, node_id: str, method_name: str, params: Dict = {} + self, reg_method_data: RegisteredMethodData ) -> Dict[str, Any]: # Mark that the node hasn't responsed - self.node_controllers[node_id].response = False - self.logger.debug( - f"{self}: Requesting registered method: {method_name}@{node_id}" - ) - - event_data = SendMessageEvent( - client_id=node_id, - signal=WORKER_MESSAGE.REQUEST_METHOD, - data={"method_name": method_name, "params": params}, + self.node_controllers[reg_method_data.node_id].response = False + + await 
self.entrypoint.emit( + "send", + ServerMessage( + client_id=reg_method_data.node_id, + signal=WORKER_MESSAGE.REQUEST_METHOD, + data=reg_method_data.to_dict(), + ), ) - await self.eventbus.asend(Event("send", event_data)) # Then wait for the Node response success = await async_waiting_for( - condition=lambda: self.node_controllers[node_id].response is True, + condition=lambda: self.node_controllers[reg_method_data.node_id].response + is True, ) return { "success": success, - "output": self.node_controllers[node_id].registered_method_results, + "output": self.node_controllers[ + reg_method_data.node_id + ].registered_method_results, } + @registry.on("diagnostics", bool, namespace=f"{__name__}.NodeHandlerService") async def async_diagnostics(self, enable: bool) -> bool: - await self.eventbus.asend( - Event( - "broadcast", - BroadcastEvent( - signal=WORKER_MESSAGE.DIAGNOSTICS, data={"enable": enable} - ), - ) + + await self.entrypoint.emit( + "broadcast", + ServerMessage(signal=WORKER_MESSAGE.DIAGNOSTICS, data={"enable": enable}), ) return True + @registry.on("gather_nodes", namespace=f"{__name__}.NodeHandlerService") async def async_gather(self) -> Dict: # self.logger.debug(f"{self}: reporting to Manager gather request") @@ -397,14 +354,17 @@ async def async_gather(self) -> Dict: self.node_controllers[node_id].response = False # Request gather from Worker to Nodes - await self.eventbus.asend( - Event("broadcast", BroadcastEvent(signal=WORKER_MESSAGE.REQUEST_GATHER)) + await self.entrypoint.emit( + "broadcast", + ServerMessage( + signal=WORKER_MESSAGE.REQUEST_GATHER, + ), ) # Wait until all Nodes have gather success = [] for node_id in self.state.nodes: - for i in range(config.get("worker.allowed-failures")): + for _ in range(config.get("worker.allowed-failures")): if await async_waiting_for( condition=lambda: self.node_controllers[node_id].response is True, @@ -416,7 +376,7 @@ async def async_gather(self) -> Dict: success.append(True) break else: - self.logger.debug( 
+ self.logger.error( f"{self}: Node {node_id} responded to gather: FAILED" ) success.append(False) @@ -435,11 +395,15 @@ async def async_gather(self) -> Dict: return gather_data + @registry.on("collect", namespace=f"{__name__}.NodeHandlerService") async def async_collect(self) -> bool: # Request saving from Worker to Nodes - await self.eventbus.asend( - Event("broadcast", BroadcastEvent(signal=WORKER_MESSAGE.REQUEST_COLLECT)) + await self.entrypoint.emit( + "broadcast", + ServerMessage( + signal=WORKER_MESSAGE.REQUEST_COLLECT, + ), ) # Now wait until all nodes have responded as CONNECTED @@ -457,7 +421,7 @@ async def async_collect(self) -> bool: success.append(True) break else: - self.logger.debug( + self.logger.error( f"{self}: Node {node_id} responded to saving request: FAILED" ) success.append(False) diff --git a/chimerapy/engine/worker/worker.py b/chimerapy/engine/worker/worker.py index 4463eaab..5d115287 100644 --- a/chimerapy/engine/worker/worker.py +++ b/chimerapy/engine/worker/worker.py @@ -1,3 +1,4 @@ +import asyncio import pathlib import shutil import tempfile @@ -5,17 +6,20 @@ import uuid from asyncio import Task from concurrent.futures import Future -from typing import Any, Coroutine, Dict, List, Literal, Optional, Tuple, Union +from typing import Coroutine, Dict, List, Literal, Optional, Union import asyncio_atexit +from aiodistbus import EntryPoint, EventBus, make_evented from chimerapy.engine import _logger, config -from ..eventbus import Event, EventBus, make_evented +from ..data_protocols import ConnectData from ..logger.zmq_handlers import NodeIDZMQPullListener from ..networking.async_loop_thread import AsyncLoopThread -from ..node import NodeConfig +from ..service import Service from ..states import NodeState, WorkerState + +# Services from .http_client_service import HttpClientService from .http_server_service import HttpServerService from .node_handler_service import NodeHandlerService @@ -65,6 +69,7 @@ def __init__( self.state = 
WorkerState(id=id, name=name, port=port, tempfolder=tempfolder) # Creating a container for task futures + self.services: List[Service] = [] self.task_futures: List[Future] = [] # Instance variables @@ -72,10 +77,20 @@ def __init__( self.shutdown_task: Optional[Task] = None async def aserve(self) -> bool: + """Start the Worker's services. + + This method will start the Worker's services, such as the HTTP + server and client, and the Node handler. It will also create + the event bus and the logging artifacts. + """ # Create the event bus for the Worker - self.eventbus = EventBus() - self.state = make_evented(self.state, event_bus=self.eventbus) + self.bus = EventBus() + self.entrypoint = EntryPoint() + await self.entrypoint.connect(self.bus) + + # Make the state evented + self.state = make_evented(self.state, bus=self.bus) # Create logging artifacts parent_logger = _logger.getLogger("chimerapy-engine-worker") @@ -85,33 +100,34 @@ async def aserve(self) -> bool: self.logreceiver = self._start_log_receiver() # Create the services - self.http_client = HttpClientService( - name="http_client", - state=self.state, - eventbus=self.eventbus, - logger=self.logger, - logreceiver=self.logreceiver, + self.services.append( + HttpClientService( + name="http_client", + state=self.state, + logger=self.logger, + logreceiver=self.logreceiver, + ) ) - self.http_server = HttpServerService( - name="http_server", - state=self.state, - eventbus=self.eventbus, - logger=self.logger, + self.services.append( + HttpServerService( + name="http_server", + state=self.state, + logger=self.logger, + ) ) - self.node_handler = NodeHandlerService( - name="node_handler", - state=self.state, - eventbus=self.eventbus, - logger=self.logger, - logreceiver=self.logreceiver, + self.services.append( + NodeHandlerService( + name="node_handler", + state=self.state, + logger=self.logger, + logreceiver=self.logreceiver, + ) ) - - await self.http_client.async_init() - await self.http_server.async_init() - await 
self.node_handler.async_init() + for service in self.services: + await service.attach(self.bus) # Start all services - await self.eventbus.asend(Event("start")) + await self.entrypoint.emit("start") self._alive = True # Register shutdown @@ -195,9 +211,9 @@ async def async_connect( self, host: Optional[str] = None, port: Optional[int] = None, - method: Optional[Literal["ip", "zeroconf"]] = "ip", + method: Literal["ip", "zeroconf"] = "ip", timeout: Union[int, float] = config.get("worker.timeout.info-request"), - ) -> bool: + ): """Connect ``Worker`` to ``Manager``. This establish server-client connections between ``Worker`` and @@ -213,47 +229,15 @@ async def async_connect( port (int): The ``Manager``'s port number timeout (Union[int, float]): Set timeout for the connection. - Returns: - bool: Success in connecting to the Manager + Raises: + TimeoutError: If the connection is not established within the \ """ - return await self.http_client.async_connect( - host=host, port=port, method=method, timeout=timeout - ) - - async def async_deregister(self) -> bool: - return await self.http_client.async_deregister() - - async def async_create_node(self, node_config: Union[NodeConfig, Dict]) -> bool: - return await self.node_handler.async_create_node(node_config) - - async def async_destroy_node(self, node_id: str) -> bool: - return await self.node_handler.async_destroy_node(node_id=node_id) + connect_data = ConnectData(method=method, host=host, port=port) + await asyncio.wait_for(self.entrypoint.emit("connect", connect_data), timeout) - async def async_start_nodes(self) -> bool: - return await self.node_handler.async_start_nodes() - - async def async_record_nodes(self) -> bool: - return await self.node_handler.async_record_nodes() - - async def async_step(self) -> bool: - return await self.node_handler.async_step() - - async def async_stop_nodes(self) -> bool: - return await self.node_handler.async_stop_nodes() - - async def async_request_registered_method( - self, node_id: 
str, method_name: str, params: Dict = {} - ) -> Dict[str, Any]: - return await self.node_handler.async_request_registered_method( - node_id=node_id, method_name=method_name, params=params - ) - - async def async_gather(self) -> Dict: - return await self.node_handler.async_gather() - - async def async_collect(self) -> bool: - return await self.node_handler.async_collect() + async def async_deregister(self): + await self.entrypoint.emit("deregister") async def async_shutdown(self) -> bool: @@ -264,7 +248,7 @@ async def async_shutdown(self) -> bool: self._alive = False # Shutdown all services and Wait until all complete - await self.eventbus.asend(Event("shutdown")) + await self.entrypoint.emit("shutdown") # Delete temp folder if requested if self.state.tempfolder.exists() and self.delete_temp: @@ -282,10 +266,10 @@ def connect( self, host: Optional[str] = None, port: Optional[int] = None, - method: Optional[Literal["ip", "zeroconf"]] = "ip", + method: Literal["ip", "zeroconf"] = "ip", timeout: Union[int, float] = 10.0, blocking: bool = True, - ) -> Union[bool, Future[bool]]: + ): """Connect ``Worker`` to ``Manager``. This establish server-client connections between ``Worker`` and @@ -300,8 +284,9 @@ def connect( timeout (Union[int, float]): Set timeout for the connection. blocking (bool): Make the connection call blocking. - Returns: - Future[bool]: Success in connecting to the Manager + Raises: + TimeoutError: If the connection is not established within the \ + timeout period. 
""" future = self._exec_coro(self.async_connect(host, port, method, timeout)) @@ -313,42 +298,6 @@ def connect( def deregister(self) -> Future[bool]: return self._exec_coro(self.async_deregister()) - def create_node(self, node_config: NodeConfig) -> Future[bool]: - return self._exec_coro(self.async_create_node(node_config)) - - def destroy_node(self, node_id: str) -> Future[bool]: - return self._exec_coro(self.async_destroy_node(node_id)) - - def step(self) -> Future[bool]: - return self._exec_coro(self.async_step()) - - def start_nodes(self) -> Future[bool]: - return self._exec_coro(self.async_start_nodes()) - - def record_nodes(self) -> Future[bool]: - return self._exec_coro(self.async_record_nodes()) - - def request_registered_method( - self, - node_id: str, - method_name: str, - params: Dict = {}, - ) -> Future[Tuple[bool, Any]]: - return self._exec_coro( - self.async_request_registered_method( - node_id=node_id, method_name=method_name, params=params - ) - ) - - def stop_nodes(self) -> Future[bool]: - return self._exec_coro(self.async_stop_nodes()) - - def gather(self) -> Future[Dict]: - return self._exec_coro(self.async_gather()) - - def collect(self) -> Future[bool]: - return self._exec_coro(self.async_collect()) - def idle(self): while self._alive: diff --git a/pyproject.toml b/pyproject.toml index 710f9174..13e40ecf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ classifiers = [ dependencies = [ 'networkx', 'dill', + 'aiodistbus', 'matplotlib', 'multiprocess', 'opencv-python', @@ -53,7 +54,6 @@ test = [ 'auto-changelog', 'coveralls', 'pre-commit', - 'docker', 'numpy', 'imutils', 'pillow', diff --git a/test/conftest.py b/test/conftest.py index 1b704db2..c1ef564e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -7,12 +7,11 @@ import time from typing import Dict -import docker import pytest +from aiodistbus import EntryPoint, EventBus import chimerapy.engine as cpe - -from .mock import DockeredWorker +from 
chimerapy.engine.networking.publisher import Publisher logger = cpe._logger.getLogger("chimerapy-engine") @@ -72,6 +71,29 @@ def event_loop(): loop.close() +@pytest.fixture +async def bus(): + bus = EventBus() + yield bus + await bus.close() + + +@pytest.fixture +async def entrypoint(bus): + entrypoint = EntryPoint() + await entrypoint.connect(bus) + yield entrypoint + await entrypoint.close() + + +@pytest.fixture +def pub(): + pub = Publisher() + pub.start() + yield pub + pub.shutdown() + + @pytest.fixture def logreceiver(): listener = cpe._logger.get_node_id_zmq_listener() @@ -103,71 +125,11 @@ async def worker(): await worker.async_shutdown() -@pytest.fixture -def docker_client(): - logger.info(f"DOCKER CLIENT: {current_platform}") - c = docker.DockerClient(base_url="unix://var/run/docker.sock") - return c - - @pytest.fixture(autouse=True) def disable_file_logging(): cpe.config.set("manager.logs-sink.enabled", False) -@pytest.fixture -def dockered_worker(docker_client): - logger.info(f"DOCKER WORKER: {current_platform}") - dockered_worker = DockeredWorker(docker_client, name="test") - yield dockered_worker - dockered_worker.shutdown() - - -class LowFrequencyNode(cpe.Node): - def setup(self): - self.i = 0 - - def step(self): - data_chunk = cpe.DataChunk() - if self.i == 0: - time.sleep(0.5) - self.i += 1 - data_chunk.add("i", self.i) - return data_chunk - else: - time.sleep(3) - self.i += 1 - data_chunk.add("i", self.i) - return data_chunk - - -class HighFrequencyNode(cpe.Node): - def setup(self): - self.i = 0 - - def step(self): - time.sleep(0.1) - self.i += 1 - data_chunk = cpe.DataChunk() - data_chunk.add("i", self.i) - return data_chunk - - -class SubsequentNode(cpe.Node): - def setup(self): - self.record = {} - - def step(self, data: Dict[str, cpe.DataChunk]): - - for k, v in data.items(): - self.record[k] = v - - data_chunk = cpe.DataChunk() - data_chunk.add("record", self.record) - - return data_chunk - - class GenNode(cpe.Node): def setup(self): 
self.value = 2 @@ -229,164 +191,3 @@ def graph(gen_node, con_node): _graph.add_edge(src=gen_node, dst=con_node) return _graph - - -@pytest.fixture -def single_node_no_connections_manager(manager, worker, gen_node): - - # Define graph - simple_graph = cpe.Graph() - simple_graph.add_nodes_from([gen_node]) - - # Connect to the manager - worker.connect(host=manager.host, port=manager.port) - - # Then register graph to Manager - assert manager.commit_graph( - simple_graph, - { - worker.id: [gen_node.id], - }, - ).result(timeout=30) - - return manager - - -@pytest.fixture -def multiple_nodes_one_worker_manager(manager, worker, gen_node, con_node): - - # Define graph - graph = cpe.Graph() - graph.add_nodes_from([gen_node, con_node]) - graph.add_edge(gen_node, con_node) - - # Connect to the manager - worker.connect(host=manager.host, port=manager.port) - - # Then register graph to Manager - assert manager.commit_graph( - graph, - { - worker.id: [gen_node.id, con_node.id], - }, - ).result(timeout=30) - - return manager - - -@pytest.fixture -def multiple_nodes_multiple_workers_manager(manager, gen_node, con_node): - - # Define graph - graph = cpe.Graph() - graph.add_nodes_from([gen_node, con_node]) - graph.add_edge(gen_node, con_node) - - worker1 = cpe.Worker(name="local", port=0) - worker2 = cpe.Worker(name="local2", port=0) - - worker1.connect(method="ip", host=manager.host, port=manager.port) - worker2.connect(method="ip", host=manager.host, port=manager.port) - - # Then register graph to Manager - assert manager.commit_graph( - graph, {worker1.id: [gen_node.id], worker2.id: [con_node.id]} - ).result(timeout=60) - - yield manager - - worker1.shutdown() - worker2.shutdown() - - -@pytest.fixture -def slow_single_node_single_worker_manager(manager, worker, slow_node): - - # Define graph - simple_graph = cpe.Graph() - simple_graph.add_nodes_from([slow_node]) - - # Connect to the manager - worker.connect(host=manager.host, port=manager.port) - - # Then register graph to 
Manager - assert manager.commit_graph( - simple_graph, - { - worker.id: [slow_node.id], - }, - ).result(timeout=30) - - return manager - - -@pytest.fixture -def dockered_single_node_no_connections_manager(dockered_worker, manager, gen_node): - - # Define graph - simple_graph = cpe.Graph() - simple_graph.add_nodes_from([gen_node]) - - # Connect to the manager - dockered_worker.connect(host=manager.host, port=manager.port) - - # Then register graph to Manager - assert manager.commit_graph( - simple_graph, - { - dockered_worker.id: [gen_node.id], - }, - ).result(timeout=30) - - return manager - - -@pytest.fixture -def dockered_multiple_nodes_one_worker_manager( - dockered_worker, manager, gen_node, con_node -): - - # Define graph - simple_graph = cpe.Graph() - simple_graph.add_nodes_from([gen_node, con_node]) - simple_graph.add_edge(gen_node, con_node) - - # Connect to the manager - dockered_worker.connect(host=manager.host, port=manager.port) - - # Then register graph to Manager - assert manager.commit_graph( - simple_graph, - { - dockered_worker.id: [gen_node.id, con_node.id], - }, - ).result(timeout=30) - - return manager - - -@pytest.fixture -def dockered_multiple_nodes_multiple_workers_manager( - docker_client, manager, gen_node, con_node -): - - # Define graph - graph = cpe.Graph() - graph.add_nodes_from([gen_node, con_node]) - graph.add_edge(gen_node, con_node) - - worker1 = DockeredWorker(docker_client, name="local") - worker2 = DockeredWorker(docker_client, name="local2") - - worker1.connect(host=manager.host, port=manager.port) - worker2.connect(host=manager.host, port=manager.port) - - # Then register graph to Manager - assert manager.commit_graph( - graph, {worker1.id: [gen_node.id], worker2.id: [con_node.id]} - ).result(timeout=30) - - yield manager - - worker1.shutdown() - worker2.shutdown() diff --git a/test/front_end_integration/__init__.py b/test/core/__init__.py similarity index 100% rename from test/front_end_integration/__init__.py rename to 
test/core/__init__.py diff --git a/test/networking/__init__.py b/test/core/networking/__init__.py similarity index 100% rename from test/networking/__init__.py rename to test/core/networking/__init__.py diff --git a/test/networking/test_client_server.py b/test/core/networking/test_client_server.py similarity index 98% rename from test/networking/test_client_server.py rename to test/core/networking/test_client_server.py index 8104dc5f..1c3c2804 100644 --- a/test/networking/test_client_server.py +++ b/test/core/networking/test_client_server.py @@ -12,11 +12,12 @@ import chimerapy.engine as cpe from chimerapy.engine.networking import Client, Server +from ...conftest import TEST_DIR + logger = cpe._logger.getLogger("chimerapy-engine") cpe.debug() # Constants -TEST_DIR = pathlib.Path(os.path.abspath(__file__)).parent.parent IMG_SIZE = 400 NUMBER_OF_CLIENTS = 5 diff --git a/test/networking/test_subscriber_publisher.py b/test/core/networking/test_subscriber_publisher.py similarity index 97% rename from test/networking/test_subscriber_publisher.py rename to test/core/networking/test_subscriber_publisher.py index 2ba4d21b..0156a567 100644 --- a/test/networking/test_subscriber_publisher.py +++ b/test/core/networking/test_subscriber_publisher.py @@ -78,5 +78,5 @@ def update(datas: Dict[str, bytes]): await subscriber.start() await publisher.publish(data_chunk.to_bytes()) - await asyncio.wait_for(flag.wait(), timeout=5) + await asyncio.wait_for(flag.wait(), timeout=15) assert expected_data_chunk == data_chunk diff --git a/test/test_async_loop_thread.py b/test/core/test_async_loop_thread.py similarity index 100% rename from test/test_async_loop_thread.py rename to test/core/test_async_loop_thread.py diff --git a/test/test_async_timer.py b/test/core/test_async_timer.py similarity index 100% rename from test/test_async_timer.py rename to test/core/test_async_timer.py diff --git a/test/test_data_chunk.py b/test/core/test_data_chunk.py similarity index 100% rename from 
test/test_data_chunk.py rename to test/core/test_data_chunk.py diff --git a/test/test_pipeline.py b/test/core/test_pipeline.py similarity index 100% rename from test/test_pipeline.py rename to test/core/test_pipeline.py diff --git a/test/manager/test_http_server_service.py b/test/manager/test_http_server_service.py index b4ddf881..1e6f1dd5 100644 --- a/test/manager/test_http_server_service.py +++ b/test/manager/test_http_server_service.py @@ -2,8 +2,8 @@ import aiohttp import pytest +from aiodistbus import make_evented -from chimerapy.engine.eventbus import EventBus, make_evented from chimerapy.engine.manager.http_server_service import HttpServerService from chimerapy.engine.states import ManagerState, WorkerState @@ -11,21 +11,19 @@ @pytest.fixture -async def http_server(): +async def http_server(bus): # Creating the configuration for the eventbus and dataclasses - event_bus = EventBus() - state = make_evented(ManagerState(), event_bus=event_bus) + state = make_evented(ManagerState(), bus=bus) # Create the services http_server = HttpServerService( name="http_server", port=0, enable_api=True, - eventbus=event_bus, state=state, ) - await http_server.async_init() + await http_server.attach(bus) await http_server.start() return http_server diff --git a/test/manager/test_manager.py b/test/manager/test_manager.py index ac72d599..22464e62 100644 --- a/test/manager/test_manager.py +++ b/test/manager/test_manager.py @@ -71,8 +71,8 @@ class TestLifeCycle: # for node_id in config_graph.G.nodes(): # assert manager.workers[_worker.id].nodes[node_id].fsm != "NULL" - @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) - # @pytest.mark.parametrize("context", ["multiprocessing"]) + # @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) + @pytest.mark.parametrize("context", ["multiprocessing"]) async def test_manager_lifecycle(self, manager_with_worker, context): manager, worker = manager_with_worker @@ -87,13 +87,13 @@ async def 
test_manager_lifecycle(self, manager_with_worker, context): mapping = {worker.id: [gen_node.id, con_node.id]} await manager.async_commit(graph, mapping, context=context) - assert await manager.async_start() - assert await manager.async_record() + await manager.async_start() + await manager.async_record() await asyncio.sleep(3) - assert await manager.async_stop() - assert await manager.async_collect() + await manager.async_stop() + await manager.async_collect() await manager.async_reset() @@ -113,13 +113,13 @@ async def test_manager_reset(self, manager_with_worker): await manager.async_reset() await manager.async_commit(graph=simple_graph, mapping=mapping) - assert await manager.async_start() - assert await manager.async_record() + await manager.async_start() + await manager.async_record() await asyncio.sleep(3) - assert await manager.async_stop() - assert await manager.async_collect() + await manager.async_stop() + await manager.async_collect() await manager.async_reset() @@ -141,18 +141,18 @@ async def test_manager_recommit_graph(self, manager_with_worker): logger.debug("STARTING COMMIT 1st ROUND") tic = time.time() - assert await manager.async_commit(**graph_info) + await manager.async_commit(**graph_info) toc = time.time() delta = toc - tic logger.debug("FINISHED COMMIT 1st ROUND") logger.debug("STARTING RESET") - assert await manager.async_reset() + await manager.async_reset() logger.debug("FINISHED RESET") logger.debug("STARTING COMMIT 2st ROUND") tic2 = time.time() - assert await manager.async_commit(**graph_info) + await manager.async_commit(**graph_info) toc2 = time.time() delta2 = toc2 - tic2 logger.debug("FINISHED COMMIT 2st ROUND") diff --git a/test/manager/test_worker_handler_service.py b/test/manager/test_worker_handler_service.py index fa71921d..272a021e 100644 --- a/test/manager/test_worker_handler_service.py +++ b/test/manager/test_worker_handler_service.py @@ -3,10 +3,10 @@ import tempfile import pytest +from aiodistbus import make_evented import 
chimerapy.engine as cpe -from chimerapy.engine import config -from chimerapy.engine.eventbus import Event, EventBus, make_evented +from chimerapy.engine.data_protocols import CommitData from chimerapy.engine.manager.http_server_service import HttpServerService from chimerapy.engine.manager.worker_handler_service import WorkerHandlerService from chimerapy.engine.states import ManagerState @@ -18,16 +18,13 @@ @pytest.fixture -async def testbed_setup(): +async def testbed_setup(bus, entrypoint): # Creating worker to communicate worker = cpe.Worker(name="local", id="local", port=0) await worker.aserve() - eventbus = EventBus() - state = make_evented( - ManagerState(logdir=pathlib.Path(tempfile.mkdtemp())), event_bus=eventbus - ) + state = make_evented(ManagerState(logdir=pathlib.Path(tempfile.mkdtemp())), bus=bus) # Define graph gen_node = GenNode(name="Gen1", id="Gen1") @@ -41,23 +38,20 @@ async def testbed_setup(): name="http_server", port=0, enable_api=True, - eventbus=eventbus, state=state, ) - worker_handler = WorkerHandlerService( - name="worker_handler", eventbus=eventbus, state=state - ) - await http_server.async_init() - await worker_handler.async_init() + worker_handler = WorkerHandlerService(name="worker_handler", state=state) + await http_server.attach(bus) + await worker_handler.attach(bus) - await eventbus.asend(Event("start")) + await entrypoint.emit("start") # Register worker await worker.async_connect(host=http_server.ip, port=http_server.port) yield (worker_handler, worker, simple_graph) - await eventbus.asend(Event("shutdown")) + await entrypoint.emit("shutdown") await worker.async_shutdown() @@ -116,9 +110,11 @@ async def test_worker_handler_lifecycle_graph(testbed_setup): # Register graph worker_handler._register_graph(simple_graph) - assert await worker_handler.commit( - graph=worker_handler.graph, mapping={worker.id: ["Gen1", "Con1"]} + commit_data = CommitData( + graph=worker_handler.graph, + mapping={worker.id: ["Gen1", "Con1"]}, ) + assert 
await worker_handler.commit(commit_data) assert await worker_handler.start_workers() await asyncio.sleep(1) @@ -128,31 +124,3 @@ async def test_worker_handler_lifecycle_graph(testbed_setup): # Teardown assert await worker_handler.reset() - - -async def test_worker_handler_enable_diagnostics(testbed_setup): - worker_handler, worker, simple_graph = testbed_setup - - config.set("diagnostics.interval", 2) - config.set("diagnostics.logging-enabled", True) - - # Register graph - worker_handler._register_graph(simple_graph) - - assert await worker_handler.commit( - graph=worker_handler.graph, mapping={worker.id: ["Gen1", "Con1"]} - ) - assert await worker_handler.start_workers() - await worker_handler.diagnostics(enable=True) - - await asyncio.sleep(4) - await worker_handler.diagnostics(enable=False) - - assert await worker_handler.stop() - assert await worker_handler.collect() - - # Teardown - assert await worker_handler.reset() - - session_folder = list(worker_handler.state.logdir.iterdir())[0] - assert (session_folder / "Con1" / "diagnostics.csv").exists() diff --git a/test/front_end_integration/test_ws.py b/test/manager/test_ws_client.py similarity index 77% rename from test/front_end_integration/test_ws.py rename to test/manager/test_ws_client.py index 53687892..55d91625 100644 --- a/test/front_end_integration/test_ws.py +++ b/test/manager/test_ws_client.py @@ -76,31 +76,6 @@ async def test_worker_network_updates(test_ws_client, manager, worker): assert record.network_state.to_json() == manager.state.to_json() -async def test_node_creation_and_destruction_network_updates( - test_ws_client, manager, worker -): - client, record = test_ws_client - - # Create original containers - simple_graph = cpe.Graph() - new_node = GenNode(name="Gen1", id="Gen1") - simple_graph.add_nodes_from([new_node]) - - # Connect to the manager - await worker.async_connect(host=manager.host, port=manager.port) - manager._register_graph(simple_graph) - - # Test construction - await 
manager._async_request_node_creation(worker_id=worker.id, node_id="Gen1") - await asyncio.sleep(2) - # assert record.network_state.to_json() == manager.state.to_json() - - # Test destruction - await manager._async_request_node_destruction(worker_id=worker.id, node_id="Gen1") - await asyncio.sleep(2) - assert record.network_state.workers[worker.id].nodes == {} - - async def test_reset_network_updates(test_ws_client, manager, worker): client, record = test_ws_client @@ -117,7 +92,7 @@ async def test_reset_network_updates(test_ws_client, manager, worker): assert record.network_state.to_json() == manager.state.to_json() # Reset - assert await manager.async_reset() + await manager.async_reset() await asyncio.sleep(3) assert record.network_state.to_json() == manager.state.to_json() diff --git a/test/manager/test_zeroconf_service.py b/test/manager/test_zeroconf_service.py index 364a0806..03a47f4c 100644 --- a/test/manager/test_zeroconf_service.py +++ b/test/manager/test_zeroconf_service.py @@ -5,7 +5,6 @@ import zeroconf from zeroconf import ServiceBrowser, ServiceInfo, ServiceListener, Zeroconf -from chimerapy.engine.eventbus import EventBus from chimerapy.engine.manager.zeroconf_service import ZeroconfService from chimerapy.engine.states import ManagerState @@ -42,13 +41,12 @@ def add_service(self, zeroconf, type, name): @pytest.fixture -async def zeroconf_service(): +async def zeroconf_service(bus): - eventbus = EventBus() state = ManagerState() - zeroconf_service = ZeroconfService("zeroconf", eventbus, state) - await zeroconf_service.async_init() + zeroconf_service = ZeroconfService("zeroconf", state) + await zeroconf_service.attach(bus) zeroconf_service.start() return zeroconf_service diff --git a/test/mock/__init__.py b/test/mock/__init__.py index f2c2d31c..e69de29b 100644 --- a/test/mock/__init__.py +++ b/test/mock/__init__.py @@ -1,5 +0,0 @@ -from .dockered_worker import DockeredWorker - -__all__ = [ - "DockeredWorker", -] diff --git a/test/mock/dockered_worker.py 
b/test/mock/dockered_worker.py deleted file mode 100644 index b77ab067..00000000 --- a/test/mock/dockered_worker.py +++ /dev/null @@ -1,76 +0,0 @@ -# Built-in Imports -import queue -import threading -import uuid - -# Third-party -import docker - -import chimerapy.engine as cpe - -logger = cpe._logger.getLogger("chimerapy-engine-networking") - - -class LogThread(threading.Thread): - def __init__(self, name: str, stream, output_queue: queue.Queue): - super().__init__() - - # Saving input parameters - self.name - self.stream = stream - self.output_queue = output_queue - - def __repr__(self): - return f"" - - def run(self): - - for data in self.stream: - logger.debug(f"{self}: {data.decode()}") - self.output_queue.put(data.decode()) - - -class DockeredWorker: - def __init__(self, client: docker.DockerClient, name: str): - self.container = client.containers.run( - image="chimerapy", - auto_remove=False, - stdin_open=True, - detach=True, - # network_mode="host", # Not realistic - ) - self.name = name - - # Create id - self.id: str = str(uuid.uuid4()) - - def connect(self, host, port): - - # Connect worker to Manager through entrypoint - _, stream = self.container.exec_run( - cmd=f"cpe-worker --id {self.id} --ip {host} --port {port} --name \ - {self.name} --wport 0", - stream=True, - ) - - # Execute worker connect - self.output_queue = queue.Queue() - self.log_thread = LogThread(self.name, stream, self.output_queue) - self.log_thread.start() - - # # Wait until the connection is established - while True: - - try: - data = self.output_queue.get(timeout=15) - except queue.Empty: - raise RuntimeError("Connection failed") - - if "connection successful to Manager" in data: - break - - def shutdown(self): - - # Then wait until the container is done - self.container.kill() - self.container.wait() diff --git a/test/streams/__init__.py b/test/node/streams/__init__.py similarity index 100% rename from test/streams/__init__.py rename to test/node/streams/__init__.py diff --git 
a/test/streams/data_nodes.py b/test/node/streams/data_nodes.py similarity index 100% rename from test/streams/data_nodes.py rename to test/node/streams/data_nodes.py diff --git a/test/streams/test_audio.py b/test/node/streams/test_audio.py similarity index 84% rename from test/streams/test_audio.py rename to test/node/streams/test_audio.py index 66f6b242..a058ce23 100644 --- a/test/streams/test_audio.py +++ b/test/node/streams/test_audio.py @@ -15,16 +15,15 @@ # Internal Imports import chimerapy.engine as cpe -from chimerapy.engine.eventbus import Event, EventBus from chimerapy.engine.records.audio_record import AudioRecord +from ...conftest import TEST_DATA_DIR from .data_nodes import AudioNode logger = cpe._logger.getLogger("chimerapy-engine") # Constants CWD = pathlib.Path(os.path.abspath(__file__)).parent.parent -TEST_DATA_DIR = CWD / "data" CHUNK = 1024 FORMAT = pyaudio.paInt16 CHANNELS = 2 @@ -33,9 +32,17 @@ @pytest.fixture -def audio_node(): +def audio_node(logreceiver): # Create a node - an = AudioNode("an", CHUNK, CHANNELS, FORMAT, RATE, logdir=TEST_DATA_DIR) + an = AudioNode( + "an", + CHUNK, + CHANNELS, + FORMAT, + RATE, + logdir=TEST_DATA_DIR, + debug_port=logreceiver.port, + ) return an @@ -111,31 +118,30 @@ def test_audio_record(): assert expected_audio_path.exists() -async def test_node_save_audio_stream(audio_node): - - # Event Loop - eventbus = EventBus() +async def test_node_save_audio_stream(audio_node, bus, entrypoint): # Check that the audio was created expected_audio_path = pathlib.Path(audio_node.state.logdir) / "test.wav" - # try: - # os.remove(expected_audio_path) - # except FileNotFoundError: - # ... + try: + os.remove(expected_audio_path) + except FileNotFoundError: + ... 
# Stream - await audio_node.arun(eventbus=eventbus) + task = asyncio.create_task(audio_node.arun(bus=bus)) + await asyncio.sleep(1) # Wait to generate files - await eventbus.asend(Event("start")) + await entrypoint.emit("start") logger.debug("Finish start") - await eventbus.asend(Event("record")) + await entrypoint.emit("record") logger.debug("Finish record") await asyncio.sleep(3) - await eventbus.asend(Event("stop")) + await entrypoint.emit("stop") logger.debug("Finish stop") await audio_node.ashutdown() + await task # Check that the audio was created assert expected_audio_path.exists() diff --git a/test/streams/test_image.py b/test/node/streams/test_image.py similarity index 82% rename from test/streams/test_image.py rename to test/node/streams/test_image.py index 2b6a5ee8..bdd8d8fb 100644 --- a/test/streams/test_image.py +++ b/test/node/streams/test_image.py @@ -10,16 +10,15 @@ # Internal Imports import chimerapy.engine as cpe -from chimerapy.engine.eventbus import Event, EventBus from chimerapy.engine.records.image_record import ImageRecord +from ...conftest import TEST_DATA_DIR from .data_nodes import ImageNode logger = cpe._logger.getLogger("chimerapy-engine") # Constants CWD = pathlib.Path(os.path.abspath(__file__)).parent.parent -TEST_DATA_DIR = CWD / "data" @pytest.fixture @@ -58,10 +57,7 @@ def test_image_record(): assert expected_image_path.exists() -async def test_node_save_image_stream(image_node): - - # Event Loop - eventbus = EventBus() +async def test_node_save_image_stream(image_node, bus, entrypoint): # Check that the image was created expected_image_path = pathlib.Path(image_node.state.logdir) / "test" / "0.png" @@ -71,18 +67,20 @@ async def test_node_save_image_stream(image_node): ... 
# Stream - await image_node.arun(eventbus=eventbus) + task = asyncio.create_task(image_node.arun(bus=bus)) + await asyncio.sleep(1) # Wait to generate files - await eventbus.asend(Event("start")) + await entrypoint.emit("start") logger.debug("Finish start") - await eventbus.asend(Event("record")) + await entrypoint.emit("record") logger.debug("Finish record") await asyncio.sleep(3) - await eventbus.asend(Event("stop")) + await entrypoint.emit("stop") logger.debug("Finish stop") await image_node.ashutdown() + await task # Check that the image was created assert expected_image_path.exists() diff --git a/test/streams/test_json.py b/test/node/streams/test_json.py similarity index 87% rename from test/streams/test_json.py rename to test/node/streams/test_json.py index f64b0a65..f8384baf 100644 --- a/test/streams/test_json.py +++ b/test/node/streams/test_json.py @@ -10,16 +10,15 @@ # Internal Imports import chimerapy.engine as cpe -from chimerapy.engine.eventbus import Event, EventBus from chimerapy.engine.records.json_record import JSONRecord +from ...conftest import TEST_DATA_DIR from .data_nodes import JSONNode logger = cpe._logger.getLogger("chimerapy-engine") # Constants CWD = pathlib.Path(os.path.abspath(__file__)).parent.parent -TEST_DATA_DIR = CWD / "data" @pytest.fixture @@ -84,10 +83,7 @@ def test_image_record(): assert data_cp == data -async def test_node_save_json_stream(json_node): - - # Event Loop - eventbus = EventBus() +async def test_node_save_json_stream(json_node, bus, entrypoint): # Check that the image was created expected_jsonl_path = pathlib.Path(json_node.state.logdir) / "test.jsonl" @@ -97,18 +93,20 @@ async def test_node_save_json_stream(json_node): ... 
# Stream - await json_node.arun(eventbus=eventbus) + task = asyncio.create_task(json_node.arun(bus=bus)) + await asyncio.sleep(1) # Wait to generate files - await eventbus.asend(Event("start")) + await entrypoint.emit("start") logger.debug("Finish start") - await eventbus.asend(Event("record")) + await entrypoint.emit("record") logger.debug("Finish record") await asyncio.sleep(3) - await eventbus.asend(Event("stop")) + await entrypoint.emit("stop") logger.debug("Finish stop") await json_node.ashutdown() + await task # Check that the image was created assert expected_jsonl_path.exists() diff --git a/test/streams/test_tabular.py b/test/node/streams/test_tabular.py similarity index 82% rename from test/streams/test_tabular.py rename to test/node/streams/test_tabular.py index a9ecc6b6..8f8027f0 100644 --- a/test/streams/test_tabular.py +++ b/test/node/streams/test_tabular.py @@ -9,9 +9,9 @@ import pytest import chimerapy.engine as cpe -from chimerapy.engine.eventbus import Event, EventBus from chimerapy.engine.records.tabular_record import TabularRecord +from ...conftest import TEST_DATA_DIR from .data_nodes import TabularNode # Internal Imports @@ -19,7 +19,6 @@ # Constants CWD = pathlib.Path(os.path.abspath(__file__)).parent.parent -TEST_DATA_DIR = CWD / "data" @pytest.fixture @@ -57,10 +56,7 @@ def test_tabular_record(): assert expected_tabular_path.exists() -async def test_node_save_tabular_stream(tabular_node): - - # Event Loop - eventbus = EventBus() +async def test_node_save_tabular_stream(tabular_node, bus, entrypoint): # Check that the tabular was created expected_tabular_path = pathlib.Path(tabular_node.state.logdir) / "test.csv" @@ -70,18 +66,20 @@ async def test_node_save_tabular_stream(tabular_node): ... 
# Stream - await tabular_node.arun(eventbus=eventbus) + task = asyncio.create_task(tabular_node.arun(bus=bus)) + await asyncio.sleep(1) # Wait to generate files - await eventbus.asend(Event("start")) + await entrypoint.emit("start") logger.debug("Finish start") - await eventbus.asend(Event("record")) + await entrypoint.emit("record") logger.debug("Finish record") await asyncio.sleep(3) - await eventbus.asend(Event("stop")) + await entrypoint.emit("stop") logger.debug("Finish stop") await tabular_node.ashutdown() + await task # Check that the tabular was created assert expected_tabular_path.exists() diff --git a/test/streams/test_text.py b/test/node/streams/test_text.py similarity index 86% rename from test/streams/test_text.py rename to test/node/streams/test_text.py index 1967396b..8764b8f0 100644 --- a/test/streams/test_text.py +++ b/test/node/streams/test_text.py @@ -9,16 +9,15 @@ # Internal Imports import chimerapy.engine as cpe -from chimerapy.engine.eventbus import Event, EventBus from chimerapy.engine.records.text_record import TextRecord +from ...conftest import TEST_DATA_DIR from .data_nodes import TextNode logger = cpe._logger.getLogger("chimerapy-engine") # Constants CWD = pathlib.Path(os.path.abspath(__file__)).parent.parent -TEST_DATA_DIR = CWD / "data" @pytest.fixture @@ -69,10 +68,7 @@ def test_text_record(): assert line.strip() == (data[idx % len(data)]).strip() -async def test_node_save_text_stream(text_node): - - # Event Loop - eventbus = EventBus() +async def test_node_save_text_stream(text_node, bus, entrypoint): # Check that the image was created expected_text_path = pathlib.Path(text_node.state.logdir) / "test.text" @@ -82,18 +78,20 @@ async def test_node_save_text_stream(text_node): ... 
# Stream - await text_node.arun(eventbus=eventbus) + task = asyncio.create_task(text_node.arun(bus=bus)) + await asyncio.sleep(1) # Wait to generate files - await eventbus.asend(Event("start")) + await entrypoint.emit("start") logger.debug("Finish start") - await eventbus.asend(Event("record")) + await entrypoint.emit("record") logger.debug("Finish record") await asyncio.sleep(3) - await eventbus.asend(Event("stop")) + await entrypoint.emit("stop") logger.debug("Finish stop") await text_node.ashutdown() + await task # Check that the image was created assert expected_text_path.exists() diff --git a/test/streams/test_video.py b/test/node/streams/test_video.py similarity index 88% rename from test/streams/test_video.py rename to test/node/streams/test_video.py index 4c7e8b7f..a048735a 100644 --- a/test/streams/test_video.py +++ b/test/node/streams/test_video.py @@ -11,9 +11,9 @@ import pytest import chimerapy.engine as cpe -from chimerapy.engine.eventbus import Event, EventBus from chimerapy.engine.records.video_record import VideoRecord +from ...conftest import TEST_DATA_DIR from .data_nodes import VideoNode # Internal Imports @@ -23,7 +23,6 @@ # Constants CWD = pathlib.Path(os.path.abspath(__file__)).parent.parent -TEST_DATA_DIR = CWD / "data" @pytest.fixture @@ -120,10 +119,7 @@ def test_video_record_with_unstable_frames(): assert (num_frames - expected_num_frames) / expected_num_frames <= 0.02 -async def test_node_save_video_stream(video_node): - - # Event Loop - eventbus = EventBus() +async def test_node_save_video_stream(video_node, bus, entrypoint): # Check that the video was created expected_video_path = pathlib.Path(video_node.state.logdir) / "test.mp4" @@ -133,18 +129,20 @@ async def test_node_save_video_stream(video_node): ... 
# Stream - await video_node.arun(eventbus=eventbus) + task = asyncio.create_task(video_node.arun(bus=bus)) + await asyncio.sleep(1) # Wait to generate files - await eventbus.asend(Event("start")) + await entrypoint.emit("start") logger.debug("Finish start") - await eventbus.asend(Event("record")) + await entrypoint.emit("record") logger.debug("Finish record") await asyncio.sleep(3) - await eventbus.asend(Event("stop")) + await entrypoint.emit("stop") logger.debug("Finish stop") await video_node.ashutdown() + await task # Check that the video was created assert expected_video_path.exists() @@ -152,10 +150,7 @@ async def test_node_save_video_stream(video_node): cap.release() -async def test_node_save_video_stream_with_unstable_fps(video_node): - - # Event Loop - eventbus = EventBus() +async def test_node_save_video_stream_with_unstable_fps(video_node, bus, entrypoint): # Check that the video was created expected_video_path = pathlib.Path(video_node.state.logdir) / "test.mp4" @@ -169,18 +164,20 @@ async def test_node_save_video_stream_with_unstable_fps(video_node): rec_time = 5 # Stream - await video_node.arun(eventbus=eventbus) + task = asyncio.create_task(video_node.arun(bus=bus)) + await asyncio.sleep(1) # Wait to generate files - await eventbus.asend(Event("start")) + await entrypoint.emit("start") logger.debug("Finish start") - await eventbus.asend(Event("record")) + await entrypoint.emit("record") logger.debug("Finish record") await asyncio.sleep(rec_time) - await eventbus.asend(Event("stop")) + await entrypoint.emit("stop") logger.debug("Finish stop") await video_node.ashutdown() + await task # Check that the video was created assert expected_video_path.exists() diff --git a/test/node/test_node.py b/test/node/test_node.py index 655eeeb4..7441fa01 100644 --- a/test/node/test_node.py +++ b/test/node/test_node.py @@ -3,14 +3,15 @@ import os import pathlib import time +from threading import Thread from typing import Type import pytest +from aiodistbus import 
EntryPoint, EventBus import chimerapy.engine as cpe from chimerapy.engine import config from chimerapy.engine.data_protocols import NodePubEntry, NodePubTable -from chimerapy.engine.eventbus import Event, EventBus from chimerapy.engine.networking.enums import WORKER_MESSAGE from chimerapy.engine.node.node_config import NodeConfig from chimerapy.engine.node.worker_comms_service import WorkerCommsService @@ -85,12 +86,7 @@ async def teardown(self): @pytest.fixture -async def eventbus(): - return EventBus() - - -@pytest.fixture -async def worker_comms_setup(mock_worker): +async def worker_comms2(mock_worker): config.set("diagnostics.logging-enabled", True) @@ -103,23 +99,44 @@ async def worker_comms_setup(mock_worker): worker_config=config.config, ) - return (worker_comms, mock_worker) + return worker_comms @pytest.mark.parametrize("node_cls", [StepNode, AsyncStepNode, MainNode, AsyncMainNode]) async def test_running_node_async_in_same_process( - logreceiver, node_cls: Type[cpe.Node], eventbus + logreceiver, node_cls: Type[cpe.Node], bus ): node = node_cls(name="step", debug_port=logreceiver.port) - await node.arun(eventbus=eventbus) + task = asyncio.create_task(node.arun(bus=bus)) await node.ashutdown() + await task @pytest.mark.parametrize("node_cls", [StepNode, AsyncStepNode, MainNode, AsyncMainNode]) -def test_running_node_in_same_process(logreceiver, node_cls: Type[cpe.Node], eventbus): +def test_running_node_in_same_process(logreceiver, node_cls: Type[cpe.Node], bus): node = node_cls(name="step", debug_port=logreceiver.port) - node.run(eventbus=eventbus) + thread = Thread(target=node.run, args=(bus,)) + thread.start() + node.shutdown() + thread.join() + + +@pytest.mark.parametrize("node_cls", [StepNode, AsyncStepNode, MainNode, AsyncMainNode]) +def test_running_node_in_process(logreceiver, node_cls: Type[cpe.Node]): + node = node_cls(name="step", debug_port=logreceiver.port) + + running = mp.Value("i", True) + p = mp.Process( + target=node.run, + args=( + None, 
+ running, + ), + ) + p.start() node.shutdown() + running.value = False + p.join() @pytest.mark.parametrize( @@ -132,7 +149,7 @@ def test_running_node_in_same_process(logreceiver, node_cls: Type[cpe.Node], eve ], ) async def test_lifecycle_start_record_stop( - logreceiver, node_cls: Type[cpe.Node], eventbus + logreceiver, node_cls: Type[cpe.Node], bus, entrypoint ): # Create the node @@ -140,26 +157,29 @@ async def test_lifecycle_start_record_stop( # Running logger.debug(f"Running Node: {node_cls}") - await node.arun(eventbus=eventbus) + task = asyncio.create_task(node.arun(bus=bus)) + logger.debug(f"Outside Node: {node_cls}") + await asyncio.sleep(1) # Necessary to let the Node's services to startup # Wait - await eventbus.asend(Event("start")) + await entrypoint.emit("start") logger.debug("Finish start") await asyncio.sleep(0.5) - await eventbus.asend(Event("record")) + await entrypoint.emit("record") logger.debug("Finish record") await asyncio.sleep(0.5) - await eventbus.asend(Event("stop")) + await entrypoint.emit("stop") logger.debug("Finish stop") await asyncio.sleep(0.5) - await eventbus.asend(Event("collect")) + await entrypoint.emit("collect") logger.debug("Finish collect") logger.debug("Shutting down Node") await node.ashutdown() + await task @pytest.mark.parametrize( @@ -172,16 +192,15 @@ async def test_lifecycle_start_record_stop( ], ) async def test_node_in_process( - logreceiver, node_cls: Type[cpe.Node], worker_comms_setup + logreceiver, node_cls: Type[cpe.Node], worker_comms2, mock_worker ): - worker_comms, mock_worker = worker_comms_setup # Create the node node = node_cls(name="step", debug_port=logreceiver.port) id = node.id # Add worker_comms - node.add_worker_comms(worker_comms) + node.add_worker_comms(worker_comms2) running = mp.Value("i", True) # Adding shared variable that would be typically added by the Worker @@ -194,7 +213,7 @@ async def test_node_in_process( ) p.start() logger.debug(f"Running Node: {node_cls}") - await asyncio.sleep(0.5) 
+ await asyncio.sleep(1) # Run method await mock_worker.server.async_send( @@ -218,6 +237,7 @@ async def test_node_in_process( await asyncio.sleep(0.25) node.shutdown() + running.value = False logger.debug("Shutting down Node") p.join() @@ -226,24 +246,24 @@ async def test_node_in_process( @linux_run_only @pytest.mark.parametrize("context", ["fork", "spawn"]) async def test_node_in_process_different_context( - logreceiver, worker_comms_setup, context + logreceiver, worker_comms2, context, mock_worker ): - worker_comms, mock_worker = worker_comms_setup # Create the node node = StepNode(name="step", debug_port=logreceiver.port) id = node.id # Add worker_comms - node.add_worker_comms(worker_comms) + node.add_worker_comms(worker_comms2) # Adding shared variable that would be typically added by the Worker + running = mp.Value("i", True) ctx = mp.get_context(context) p = ctx.Process( target=node.run, args=( None, - mp.Value("i", True), + running, ), ) p.start() @@ -271,6 +291,7 @@ async def test_node_in_process_different_context( await asyncio.sleep(0.25) node.shutdown() + running.value = False logger.debug("Shutting down Node") p.join() @@ -281,6 +302,10 @@ async def test_node_connection(logreceiver, mock_worker): # Create evenbus for each node g_eventbus = EventBus() c_eventbus = EventBus() + g_entrypoint = EntryPoint() + c_entrypoint = EntryPoint() + await g_entrypoint.connect(g_eventbus) + await c_entrypoint.connect(c_eventbus) # Create the node gen_node = GenNode(name="Gen1", debug_port=logreceiver.port, id="Gen1") @@ -310,8 +335,8 @@ async def test_node_connection(logreceiver, mock_worker): # Running logger.debug(f"Running Nodes: {gen_node.state}, {con_node.state}") - await gen_node.arun(eventbus=g_eventbus) - await con_node.arun(eventbus=c_eventbus) + g_task = asyncio.create_task(gen_node.arun(bus=g_eventbus)) + c_task = asyncio.create_task(con_node.arun(bus=c_eventbus)) logger.debug("Finish run") # Create the connections @@ -326,14 +351,16 @@ async def 
test_node_connection(logreceiver, mock_worker): logger.debug("Finish broadcast") # Wait - await g_eventbus.asend(Event("start")) - await c_eventbus.asend(Event("start")) + await g_entrypoint.emit("start") + await c_entrypoint.emit("start") logger.debug("Finish start") await asyncio.sleep(2) - await g_eventbus.asend(Event("stop")) - await c_eventbus.asend(Event("stop")) + await g_entrypoint.emit("stop") + await c_entrypoint.emit("stop") logger.debug("Finish stop") logger.debug("Shutting down Node") await gen_node.ashutdown() await con_node.ashutdown() + await g_task + await c_task diff --git a/test/node/test_poller_service.py b/test/node/test_poller_service.py index c204d29f..5ad716ec 100644 --- a/test/node/test_poller_service.py +++ b/test/node/test_poller_service.py @@ -4,9 +4,7 @@ import chimerapy.engine as cpe from chimerapy.engine.data_protocols import NodePubEntry, NodePubTable -from chimerapy.engine.eventbus import EventBus from chimerapy.engine.networking.data_chunk import DataChunk -from chimerapy.engine.networking.publisher import Publisher from chimerapy.engine.node.poller_service import PollerService from chimerapy.engine.states import NodeState @@ -14,10 +12,7 @@ @pytest.fixture -async def poller_setup(): - - # Event Loop - eventbus = EventBus() +async def poller(bus): # Create sample state state = NodeState() @@ -29,26 +24,17 @@ async def poller_setup(): in_bound_by_name=["pub_mock"], follow="pub_mock", state=state, - eventbus=eventbus, ) - await poller.async_init() - - pub = Publisher() - pub.start() - - yield (poller, pub) - + await poller.attach(bus) + yield poller await poller.teardown() - pub.shutdown() -async def test_instanticate(poller_setup): +async def test_instanticate(poller): ... 
-async def test_setting_connections(poller_setup): - - poller, pub = poller_setup +async def test_setting_connections(poller, pub): node_pub_table = NodePubTable( {"pub_mock": NodePubEntry(ip=pub.host, port=pub.port)} @@ -56,9 +42,7 @@ async def test_setting_connections(poller_setup): await poller.setup_connections(node_pub_table) -async def test_poll_message(poller_setup): - - poller, pub = poller_setup +async def test_poll_message(poller, pub): # Setup node_pub_table = NodePubTable( @@ -73,4 +57,4 @@ async def test_poll_message(poller_setup): # Sleep await asyncio.sleep(1) - assert poller.eventbus._event_counts > 0 + assert poller.emit_counter > 0 diff --git a/test/node/test_processor_service.py b/test/node/test_processor_service.py index b6ff4f6a..b2ea3cb5 100644 --- a/test/node/test_processor_service.py +++ b/test/node/test_processor_service.py @@ -6,9 +6,7 @@ from pytest_lazyfixture import lazy_fixture from chimerapy.engine import _logger -from chimerapy.engine.eventbus import Event, EventBus, TypedObserver from chimerapy.engine.networking.data_chunk import DataChunk -from chimerapy.engine.node.events import NewInBoundDataEvent, NewOutBoundDataEvent from chimerapy.engine.node.processor_service import ProcessorService from chimerapy.engine.states import NodeState @@ -40,25 +38,22 @@ async def shutdown(processor): await processor.teardown() -async def emit_data(eventbus): - for i in range(3): - await eventbus.asend( - Event("in_step", NewInBoundDataEvent({"data": DataChunk()})) - ) +async def emit_data(entrypoint): + await entrypoint.emit("start") + await asyncio.sleep(1) + for _ in range(3): + await entrypoint.emit("in_step", {"data": DataChunk()}) await asyncio.sleep(0.5) -async def receive_data(data_chunk): +async def receive_data(data_chunk: DataChunk): global RECEIVE_FLAG RECEIVE_FLAG = True logger.debug(data_chunk) @pytest.fixture -async def step_processor(): - - # Create eventbus - eventbus = EventBus() +async def step_processor(bus): # Create sample 
state state = NodeState() @@ -68,22 +63,16 @@ async def step_processor(): "processor", in_bound_data=True, state=state, - eventbus=eventbus, main_fn=step, operation_mode="step", ) - await processor.async_init() - - yield (processor, eventbus) - + await processor.attach(bus) + yield processor await processor.teardown() @pytest.fixture -async def source_processor(): - - # Create eventbus - eventbus = EventBus() +async def source_processor(bus): # Create sample state state = NodeState() @@ -93,22 +82,16 @@ async def source_processor(): "processor", in_bound_data=False, state=state, - eventbus=eventbus, main_fn=step, operation_mode="step", ) - await processor.async_init() - - yield (processor, eventbus) - + await processor.attach(bus) + yield processor await processor.teardown() @pytest.fixture -async def main_processor(): - - # Create eventbus - eventbus = EventBus() +async def main_processor(bus): # Create sample state state = NodeState() @@ -118,52 +101,47 @@ async def main_processor(): "processor", in_bound_data=True, state=state, - eventbus=eventbus, main_fn=main, operation_mode="main", ) - await processor.async_init() - - yield (processor, eventbus) - + await processor.attach(bus) + yield processor await processor.teardown() @pytest.mark.parametrize( - "processor_setup", + "processor", [ lazy_fixture("source_processor"), lazy_fixture("main_processor"), lazy_fixture("step_processor"), ], ) -def test_instanticate(processor_setup): +def test_instanticate(processor): ... 
@pytest.mark.parametrize( - "processor_setup", + "processor", [ lazy_fixture("source_processor"), lazy_fixture("main_processor"), lazy_fixture("step_processor"), ], ) -async def test_setup(processor_setup): - processor, _ = processor_setup +async def test_setup(processor): await processor.setup() @pytest.mark.parametrize( - "ptype, processor_setup", + "ptype, processor", [ ("source", lazy_fixture("source_processor")), ("main", lazy_fixture("main_processor")), ("step", lazy_fixture("step_processor")), ], ) -async def test_main(ptype, processor_setup): - processor, eventbus = processor_setup +async def test_main(ptype, processor, entrypoint): # Reset global CHANGE_FLAG @@ -173,20 +151,13 @@ async def test_main(ptype, processor_setup): # Adding observer for step if ptype == "step": - observer = TypedObserver( - "out_step", - NewOutBoundDataEvent, - on_asend=receive_data, - handle_event="unpack", - ) - await eventbus.asubscribe(observer) + await entrypoint.on("out_step", receive_data, DataChunk) # Execute await processor.setup() await asyncio.gather( - processor.main(), shutdown(processor), - emit_data(eventbus), + emit_data(entrypoint), ) # Asserts diff --git a/test/node/test_profiling_service.py b/test/node/test_profiler_service.py similarity index 70% rename from test/node/test_profiling_service.py rename to test/node/test_profiler_service.py index d08ddc6f..f82ebfb3 100644 --- a/test/node/test_profiling_service.py +++ b/test/node/test_profiler_service.py @@ -7,50 +7,39 @@ import chimerapy.engine as cpe from chimerapy.engine import config -from chimerapy.engine.eventbus import Event, EventBus from chimerapy.engine.networking.data_chunk import DataChunk -from chimerapy.engine.node.events import NewOutBoundDataEvent from chimerapy.engine.node.profiler_service import ProfilerService from chimerapy.engine.states import NodeState -# from ..conftest import TEST_DATA_DIR - - logger = cpe._logger.getLogger("chimerapy-engine") @pytest.fixture -async def profiler_setup(): 
+async def profiler(bus): # Modify the configuration config.set("diagnostics.interval", 1) config.set("diagnostics.logging-enabled", True) - # Event Loop - eventbus = EventBus() - # Create sample state state = NodeState(logdir=pathlib.Path(tempfile.mkdtemp())) # Create the profiler - profiler = ProfilerService( - name="profiler", state=state, eventbus=eventbus, logger=logger - ) - await profiler.async_init() + profiler = ProfilerService(name="profiler", state=state, logger=logger) + await profiler.attach(bus) await profiler.setup() - yield (profiler, eventbus) + yield profiler await profiler.teardown() -async def test_instanciate(profiler_setup): +async def test_instanciate(profiler): ... -async def test_single_data_chunk(profiler_setup): - profiler, eventbus = profiler_setup +async def test_single_data_chunk(profiler, entrypoint): await profiler.enable() - for i in range(50): + for _ in range(50): # Run the step multiple times example_data_chunk = DataChunk() @@ -62,19 +51,16 @@ async def test_single_data_chunk(profiler_setup): meta["value"]["delta"] = random.randrange(500, 1500, 1) # ms example_data_chunk.update("meta", meta) - await eventbus.asend( - Event("out_step", NewOutBoundDataEvent(example_data_chunk)) - ) + await entrypoint.emit("out_step", example_data_chunk) await profiler.diagnostics_report() assert profiler.log_file.exists() -async def test_single_data_chunk_with_multiple_payloads(profiler_setup): - profiler, eventbus = profiler_setup +async def test_single_data_chunk_with_multiple_payloads(profiler, entrypoint): await profiler.enable() - for i in range(50): + for _ in range(50): # Run the step multiple times example_data_chunk = DataChunk() @@ -87,18 +73,15 @@ async def test_single_data_chunk_with_multiple_payloads(profiler_setup): meta["value"]["delta"] = random.randrange(500, 1500, 1) example_data_chunk.update("meta", meta) - await eventbus.asend( - Event("out_step", NewOutBoundDataEvent(example_data_chunk)) - ) + await entrypoint.emit("out_step", 
example_data_chunk) await profiler.diagnostics_report() assert profiler.log_file.exists() -async def test_enable_disable(profiler_setup): - profiler, eventbus = profiler_setup +async def test_enable_disable(profiler, entrypoint): - for i in range(50): + for _ in range(50): # Run the step multiple times example_data_chunk = DataChunk() @@ -110,14 +93,12 @@ async def test_enable_disable(profiler_setup): meta["value"]["delta"] = random.randrange(500, 1500, 1) # ms example_data_chunk.update("meta", meta) - await eventbus.asend( - Event("out_step", NewOutBoundDataEvent(example_data_chunk)) - ) + await entrypoint.emit("out_step", example_data_chunk) assert len(profiler.seen_uuids) == 0 await profiler.enable(True) - for i in range(50): + for _ in range(50): # Run the step multiple times example_data_chunk = DataChunk() @@ -129,9 +110,7 @@ async def test_enable_disable(profiler_setup): meta["value"]["delta"] = random.randrange(500, 1500, 1) # ms example_data_chunk.update("meta", meta) - await eventbus.asend( - Event("out_step", NewOutBoundDataEvent(example_data_chunk)) - ) + await entrypoint.emit("out_step", example_data_chunk) await profiler.diagnostics_report() await profiler.enable(False) diff --git a/test/node/test_record_service.py b/test/node/test_record_service.py index 744a68f1..47ca4cc4 100644 --- a/test/node/test_record_service.py +++ b/test/node/test_record_service.py @@ -7,7 +7,6 @@ import pytest import chimerapy.engine as cpe -from chimerapy.engine.eventbus import EventBus from chimerapy.engine.node.record_service import RecordService from chimerapy.engine.states import NodeState @@ -15,20 +14,17 @@ @pytest.fixture -async def recorder(): - - # Event Loop - eventbus = EventBus() +async def recorder(bus): # Create sample state state = NodeState(logdir=pathlib.Path(tempfile.mkdtemp())) - state.fsm = "PREVIEWING" + state.fsm = "RECORDING" # Create the recorder - recorder = RecordService(name="recorder", state=state, eventbus=eventbus) - await recorder.async_init() 
+ recorder = RecordService(name="recorder", state=state) + await recorder.attach(bus) yield recorder - await recorder.teardown() + recorder.teardown() async def test_instanciate(recorder): @@ -38,7 +34,7 @@ async def test_instanciate(recorder): async def test_record_direct_submit(recorder): # Run the recorder - await recorder.setup() + recorder.setup() timestamp = datetime.datetime.now() video_entry = { @@ -55,7 +51,7 @@ async def test_record_direct_submit(recorder): recorder.submit(video_entry) recorder.collect() - await recorder.teardown() + recorder.teardown() expected_file = recorder.state.logdir / "test.mp4" assert expected_file.exists() diff --git a/test/node/test_worker_comms.py b/test/node/test_worker_comms.py index ebeceece..06a7e9ea 100644 --- a/test/node/test_worker_comms.py +++ b/test/node/test_worker_comms.py @@ -1,11 +1,16 @@ +import asyncio from typing import Dict import pytest +from aiodistbus import make_evented from aiohttp import web import chimerapy.engine as cpe -from chimerapy.engine.data_protocols import NodeDiagnostics, NodePubTable -from chimerapy.engine.eventbus import EventBus +from chimerapy.engine.data_protocols import ( + NodeDiagnostics, + NodePubTable, + RegisteredMethodData, +) from chimerapy.engine.networking.enums import NODE_MESSAGE, WORKER_MESSAGE from chimerapy.engine.networking.server import Server from chimerapy.engine.node.node_config import NodeConfig @@ -34,7 +39,7 @@ async def setup(self): async def node_status_update(self, msg: Dict, ws: web.WebSocketResponse): - # self.logger.debug(f"{self}: note_status_update: ", msg) + # logger.debug(f"{self}: note_status_update: ", msg) node_state = NodeState.from_dict(msg["data"]) node_id = node_state.id @@ -63,13 +68,10 @@ async def mock_worker(): @pytest.fixture -async def worker_comms_setup(mock_worker): - - # Event Loop - eventbus = EventBus() +async def worker_comms(mock_worker, bus): # Create sample state - state = NodeState(id="test_worker_comms") + state = 
make_evented(NodeState(id="test_worker_comms"), bus=bus) node_config = NodeConfig() # Create the service @@ -79,18 +81,27 @@ async def worker_comms_setup(mock_worker): port=mock_worker.server.port, node_config=node_config, state=state, - eventbus=eventbus, logger=logger, ) - - yield (worker_comms, mock_worker.server) + await worker_comms.attach(bus) + await worker_comms.setup() + yield worker_comms + await worker_comms.teardown() await mock_worker.async_shutdown() -def test_instanticate(worker_comms_setup): +def test_instanticate(worker_comms): ... +async def test_node_state_change(worker_comms, mock_worker): + + # Change the state + worker_comms.state.fsm = "RUNNING" + await asyncio.sleep(1) + assert mock_worker.node_states[worker_comms.state.id].fsm == "RUNNING" + + @pytest.mark.parametrize( "method_name, method_params", [ @@ -98,7 +109,14 @@ def test_instanticate(worker_comms_setup): ("record_node", {}), ("stop_node", {}), ("provide_collect", {}), - ("execute_registered_method", {"data": {"method_name": "", "params": {}}}), + ( + "execute_registered_method", + { + "data": RegisteredMethodData( + node_id="1", method_name="a", params={} + ).to_dict() + }, + ), ("process_node_pub_table", {"data": NodePubTable().to_dict()}), ("async_step", {}), ("provide_gather", {}), @@ -106,19 +124,12 @@ def test_instanticate(worker_comms_setup): ("enable_diagnostics", {"data": {"enable": True}}), ], ) -async def test_methods(worker_comms_setup, method_name, method_params): - worker_comms, _ = worker_comms_setup - - # Start the server - await worker_comms.setup() +async def test_methods(worker_comms, method_name, method_params): # Run method method = getattr(worker_comms, method_name) await method(method_params) - # Shutdown - await worker_comms.teardown() - @pytest.mark.parametrize( "signal, data", @@ -134,16 +145,9 @@ async def test_methods(worker_comms_setup, method_name, method_params): (WORKER_MESSAGE.DIAGNOSTICS, {"enable": False}), ], ) -async def 
test_ws_signals(worker_comms_setup, signal, data): - worker_comms, server = worker_comms_setup - - # Start the server - await worker_comms.setup() +async def test_ws_signals(worker_comms, mock_worker, signal, data): # Run method - await server.async_send( + await mock_worker.server.async_send( client_id=worker_comms.state.id, signal=signal, data=data, ok=True ) - - # Shutdown - await worker_comms.teardown() diff --git a/test/test_docker.py b/test/test_docker.py deleted file mode 100644 index afafd131..00000000 --- a/test/test_docker.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest - -import chimerapy.engine as cpe - -from .conftest import linux_run_only - -logger = cpe._logger.getLogger("chimerapy-engine") - -# Resources: https://docker-py.readthedocs.io/en/stable/containers.html#docker.models.containers.Container -# https://stackoverflow.com/questions/61763684/following-the-exec-run-output-from-docker-py-in-realtime - - -@pytest.mark.skip(reason="Outdated") -@linux_run_only -def test_get_easy_docker_example_going(docker_client): - output = docker_client.containers.run("ubuntu", "echo $PATH") - logger.info(output) - - -@pytest.mark.skip(reason="Outdated") -@linux_run_only -def test_create_container_and_make_it_execute_commands(docker_client): - - # Create the docker container - container = docker_client.containers.run( - image="ubuntu", auto_remove=False, stdin_open=True, detach=True - ) - - # Start executing commands - output = container.exec_run(cmd="echo $(find /)") - logger.info(output) - - output = container.exec_run(cmd="whoami") - logger.info(output) - - -@pytest.mark.skip(reason="Outdated") -@linux_run_only -def test_use_custom_docker_image(docker_client): - - # Create the docker container - container = docker_client.containers.run( - image="chimerapy", auto_remove=False, stdin_open=True, detach=True - ) - - # Start executing commands - output = container.exec_run(cmd="python -c 'import chimerapy; print(chimerapy)'") - logger.info(output) diff --git 
a/test/test_eventbus.py b/test/test_eventbus.py deleted file mode 100644 index d371d825..00000000 --- a/test/test_eventbus.py +++ /dev/null @@ -1,328 +0,0 @@ -import asyncio -from dataclasses import dataclass -from typing import Any, Dict, List - -import pytest -from dataclasses_json import DataClassJsonMixin - -import chimerapy.engine as cpe -from chimerapy.engine.eventbus import ( - Event, - EventBus, - TypedObserver, - configure, - evented, - make_evented, -) -from chimerapy.engine.states import ManagerState, NodeState, WorkerState - -logger = cpe._logger.getLogger("chimerapy-engine") - - -@dataclass -class HelloEventData: - message: str - - -@dataclass -class WorldEventData: - value: int - - -@evented -@dataclass -class SomeClass(DataClassJsonMixin): - number: int - string: str - - -@dataclass -class NestedClass(DataClassJsonMixin): - number: int - subclass: HelloEventData - map: Dict[str, str] - vector: List[str] - - -@pytest.fixture -def event_bus(): - # Creating the configuration for the eventbus and dataclasses - event_bus = EventBus() - configure(event_bus) - return event_bus - - -async def test_msg_filtering(): - - event_bus = EventBus() - hello_observer = TypedObserver("hello", HelloEventData) - - # Subscribe to the event bus - await event_bus.asubscribe(hello_observer) - - # Create the event - hello_event = Event("hello", HelloEventData("Hello data")) - world_event = Event("world", WorldEventData(42)) - - # Send some events - await event_bus.asend(hello_event) - await event_bus.asend(world_event) - - assert world_event.id not in hello_observer.received - assert hello_event.id in hello_observer.received - - -async def test_event_null_data(): - - event_bus = EventBus() - null_observer = TypedObserver("null") - - # Subscribe to the event bus - await event_bus.asubscribe(null_observer) - - # Create the event - null_event = Event("null") - null2_event = Event("null2") - - # Send some events - await event_bus.asend(null_event) - await 
event_bus.asend(null2_event) - - assert null2_event.id not in null_observer.received - assert null_event.id in null_observer.received - - -async def test_subscribe_and_unsubscribe(): - - event_bus = EventBus() - null_observer = TypedObserver("null") - - # Subscribe to the event bus - await event_bus.asubscribe(null_observer) - - # Create the event - null_event = Event("null") - null2_event = Event("null") - - # Send some events - await event_bus.asend(null_event) - - # Unsubscribe and then send the event - await event_bus.aunsubscribe(null_observer) - await event_bus.asend(null2_event) - - assert null_event.id in null_observer.received - assert null2_event.id not in null_observer.received - - -async def test_awaitable_event(): - - event_bus = EventBus() - null_event = Event("null") - - async def later_event(): - await asyncio.sleep(1) - await event_bus.asend(null_event) - - asyncio.create_task(later_event()) - - null2_event = await event_bus.await_event("null") - assert null2_event == null_event - - -async def test_sync_and_async_binding(): - - event_bus = EventBus() - hello_observer = TypedObserver("hello", HelloEventData) - goodbye_observer = TypedObserver("goodbye", HelloEventData) - - # Creating handler - sync_local_variable: List = [] - async_local_variable: List = [] - - def add_to(var: List[Any]): - var.append(1) - - async def async_add_to(_): - async_local_variable.append(1) - - hello_observer.bind_asend(lambda _: add_to(sync_local_variable)) - goodbye_observer.bind_asend(async_add_to) - - # Subscribe to the event bus - await event_bus.asubscribe(hello_observer) - await event_bus.asubscribe(goodbye_observer) - - # Create the event - hello_event = Event("hello", HelloEventData("Hello data")) - goodbye_event = Event("goodbye", HelloEventData("Hello data")) - - # Send some events - await event_bus.asend(hello_event) - await event_bus.asend(goodbye_event) - - # Confirm - assert len(sync_local_variable) != 0 - assert len(async_local_variable) != 0 - - -async def 
test_event_handling(): - - event_bus = EventBus() - pass_observer = TypedObserver("hello", HelloEventData, handle_event="pass") - unpack_observer = TypedObserver("hello", HelloEventData, handle_event="unpack") - drop_observer = TypedObserver("hello", HelloEventData, handle_event="drop") - obs = [pass_observer, unpack_observer, drop_observer] - - # Creating handler - - pass_variable: List = [] - - async def pass_func(event): - assert isinstance(event, Event) - pass_variable.append(1) - - unpack_variable: List = [] - - async def unpack_func(message: str): - assert isinstance(message, str) - unpack_variable.append(1) - - drop_variable: List = [] - - async def drop_func(): - drop_variable.append(1) - - # Bind - pass_observer.bind_asend(pass_func) - unpack_observer.bind_asend(unpack_func) - drop_observer.bind_asend(drop_func) - - # Subscribe to the event bus - for ob in obs: - await event_bus.asubscribe(ob) - - # Send some events - await event_bus.asend(Event("hello", HelloEventData("Hello data"))) - - # Confirm - assert len(pass_variable) != 0 - assert len(unpack_variable) != 0 - assert len(drop_variable) != 0 - - -async def test_evented_dataclass(event_bus): - - # Creating the observer and its binding - evented_observer = TypedObserver("SomeClass.changed") - - # Creating handler - local_variable: List = [] - - async def add_to(event): - local_variable.append(1) - - evented_observer.bind_asend(add_to) - - # Subscribe to the event bus - await event_bus.asubscribe(evented_observer) - - # Create the evented class - data = SomeClass(number=1, string="hello") - - # Trigger an event by changing the class - data.number = 2 - await asyncio.sleep(1) - - # Confirm - assert len(local_variable) != 0 - assert isinstance(data.to_json(), str) - - -async def test_evented_wrapper(event_bus): - - # Creating the observer and its binding - evented_observer = TypedObserver("SomeClass.changed") - - # Creating handler - local_variable: List = [] - - async def add_to(event): - 
local_variable.append(1) - - evented_observer.bind_asend(add_to) - - # Subscribe to the event bus - await event_bus.asubscribe(evented_observer) - - # Create the evented class - data = make_evented(SomeClass(number=1, string="hello"), event_bus=event_bus) - - # Trigger an event by changing the class - logger.debug("Triggering manually") - data.number = 2 - await asyncio.sleep(1) - - # Confirm - assert len(local_variable) != 0 - assert isinstance(data.to_json(), str) - - -@pytest.mark.parametrize( - "cls, kwargs", - [ - (SomeClass, {"number": 1, "string": "hello"}), - (ManagerState, {}), - (WorkerState, {"id": "test", "name": "test"}), - (NodeState, {"id": "a"}), - ], -) -def test_make_evented(cls, kwargs, event_bus): - # Create the evented class - data = make_evented(cls(**kwargs), event_bus=event_bus) - data.to_json() - - -def test_make_evented_multiple(event_bus): - # Create the evented class - make_evented(SomeClass(number=1, string="hello"), event_bus=event_bus) - make_evented(SomeClass(number=1, string="hello"), event_bus=event_bus) - make_evented(SomeClass(number=1, string="hello"), event_bus=event_bus) - - -async def test_make_evented_nested(event_bus): - data_class = NestedClass( - number=1, - subclass=HelloEventData(message="hello"), - map={"test": "test"}, - vector=["hello", "there"], - ) - nested_data = make_evented( - data_class, - event_bus=event_bus, - ) - - logger.debug(data_class) - - nested_data.number = 5 - await asyncio.sleep(1) - a = event_bus._event_counts - assert a > 0 - - nested_data.map["new"] = "key" - await asyncio.sleep(1) - b = event_bus._event_counts - assert b > a - - nested_data.subclass.message = "goodbye" - await asyncio.sleep(1) - c = event_bus._event_counts - assert c > b - - nested_data.vector.append("this") - await asyncio.sleep(1) - d = event_bus._event_counts - assert d > c - - # Then it must also be jsonable - logger.debug(nested_data.to_json()) diff --git a/test/worker/node_handler/test_node_controller.py 
b/test/worker/node_handler/test_node_controller.py index 145750ee..789ebe03 100644 --- a/test/worker/node_handler/test_node_controller.py +++ b/test/worker/node_handler/test_node_controller.py @@ -1,4 +1,5 @@ import asyncio +from typing import List import multiprocess as mp @@ -33,6 +34,26 @@ async def test_mp_node_controller(): assert output == OUTPUT +async def test_mp_node_controller_multiple(): + mp_manager = mp.Manager() + session = MPSession() + + controllers: List[MPNodeController] = [] + for i in range(10): + node = GenNode(name=f"Gen{i}") + + node_controller = MPNodeController(node, logger) + controllers.append(node_controller) + node_controller.set_mp_manager(mp_manager) + node_controller.run(session) + + await asyncio.sleep(0.25) + + for c in controllers: + output = await c.shutdown() + assert output == OUTPUT + + async def test_thread_node_controller(): session = ThreadSession() node = GenNode(name="Gen1") diff --git a/test/worker/node_handler/test_node_handler.py b/test/worker/node_handler/test_node_handler.py index b72b5995..b71281f9 100644 --- a/test/worker/node_handler/test_node_handler.py +++ b/test/worker/node_handler/test_node_handler.py @@ -3,16 +3,18 @@ from typing import Optional, Union import pytest +from aiodistbus import EventBus, make_evented import chimerapy.engine as cpe -from chimerapy.engine.eventbus import Event, EventBus, make_evented +from chimerapy.engine import config +from chimerapy.engine.data_protocols import RegisteredMethodData from chimerapy.engine.states import WorkerState from chimerapy.engine.worker.http_server_service import HttpServerService from chimerapy.engine.worker.node_handler_service import NodeHandlerService from ...conftest import linux_run_only -from ...networking.test_client_server import server -from ...streams.data_nodes import ImageNode, TabularNode, VideoNode +from ...core.networking.test_client_server import server +from ...node.streams.data_nodes import ImageNode, TabularNode, VideoNode logger = 
cpe._logger.getLogger("chimerapy-engine") cpe.debug() @@ -72,13 +74,10 @@ def node_with_reg_methods(logreceiver): @pytest.fixture -async def node_handler_setup(): - - # Event Loop - eventbus = EventBus() +async def node_handler_setup(bus, entrypoint): # Requirements - state = make_evented(WorkerState(), event_bus=eventbus) + state = make_evented(WorkerState(), bus=bus) logger = cpe._logger.getLogger("chimerapy-engine-worker") log_receiver = cpe._logger.get_node_id_zmq_listener() log_receiver.start(register_exit_handlers=True) @@ -87,28 +86,26 @@ async def node_handler_setup(): node_handler = NodeHandlerService( name="node_handler", state=state, - eventbus=eventbus, logger=logger, logreceiver=log_receiver, ) - await node_handler.async_init() + await node_handler.attach(bus) # Necessary dependency - http_server = HttpServerService( - name="http_server", state=state, eventbus=eventbus, logger=logger - ) - await http_server.async_init() + http_server = HttpServerService(name="http_server", state=state, logger=logger) + await http_server.attach(bus) - await eventbus.asend(Event("start")) + await entrypoint.emit("start") yield (node_handler, http_server) - await eventbus.asend(Event("shutdown")) + await entrypoint.emit("shutdown") async def test_create_service_instance(node_handler_setup): ... 
-# @pytest.mark.parametrize("context", ["multiprocessing"]) # , "threading"]) +# @pytest.mark.parametrize("context", ["multiprocessing"]) +# @pytest.mark.parametrize("context", ["threading"]) @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) async def test_create_node(gen_node, node_handler_setup, context): node_handler, _ = node_handler_setup @@ -157,8 +154,9 @@ def step(self): assert await node_handler.async_destroy_node(node_id) -@pytest.mark.skip() -@pytest.mark.parametrize("context", ["multiprocessing", "threading"]) +# @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) +@pytest.mark.parametrize("context", ["multiprocessing"]) +# @pytest.mark.parametrize("context", ["threading"]) async def test_processing_node_pub_table( node_handler_setup, gen_node, con_node, context ): @@ -186,8 +184,8 @@ async def test_processing_node_pub_table( assert await node_handler.async_destroy_node(con_node.id) -# @pytest.mark.parametrize("context", ["multiprocessing"]) # , "threading"]) -@pytest.mark.parametrize("context", ["multiprocessing", "threading"]) +@pytest.mark.parametrize("context", ["multiprocessing"]) +# @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) async def test_starting_node(node_handler_setup, gen_node, context): node_handler, _ = node_handler_setup @@ -200,8 +198,8 @@ async def test_starting_node(node_handler_setup, gen_node, context): assert await node_handler.async_destroy_node(gen_node.id) -# @pytest.mark.parametrize("context", ["multiprocessing"]) # , "threading"]) -@pytest.mark.parametrize("context", ["multiprocessing", "threading"]) +@pytest.mark.parametrize("context", ["multiprocessing"]) +# @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) async def test_record_and_collect(node_handler_setup, context): node_handler, _ = node_handler_setup @@ -240,9 +238,14 @@ async def test_registered_method_with_concurrent_style( assert await 
node_handler.async_create_node(cpe.NodeConfig(node_with_reg_methods)) # Execute the registered method (with config) - results = await node_handler.async_request_registered_method( - node_id=node_with_reg_methods.id, method_name="printout" + # logger.debug(f"Requesting registered method") + reg_method_data = RegisteredMethodData( + node_id=node_with_reg_methods.id, + method_name="printout", ) + # logger.debug(f"Requesting registered method: {reg_method_data}") + results = await node_handler.async_request_registered_method(reg_method_data) + # logger.debug(f"Results: {results}") assert await node_handler.async_destroy_node(node_with_reg_methods.id) assert ( @@ -261,11 +264,12 @@ async def test_registered_method_with_params_and_blocking_style( assert await node_handler.async_create_node(cpe.NodeConfig(node_with_reg_methods)) # Execute the registered method (with config) - results = await node_handler.async_request_registered_method( + reg_method_data = RegisteredMethodData( node_id=node_with_reg_methods.id, method_name="set_value", params={"value": -100}, ) + results = await node_handler.async_request_registered_method(reg_method_data) assert await node_handler.async_destroy_node(node_with_reg_methods.id) assert ( @@ -284,10 +288,11 @@ async def test_registered_method_with_reset_style( assert await node_handler.async_create_node(cpe.NodeConfig(node_with_reg_methods)) # Execute the registered method (with config) - results = await node_handler.async_request_registered_method( + reg_method_data = RegisteredMethodData( node_id=node_with_reg_methods.id, method_name="reset", ) + results = await node_handler.async_request_registered_method(reg_method_data) assert await node_handler.async_destroy_node(node_with_reg_methods.id) @@ -298,8 +303,8 @@ async def test_registered_method_with_reset_style( ) -# @pytest.mark.parametrize("context", ["multiprocessing"]) # , "threading"]) -@pytest.mark.parametrize("context", ["multiprocessing", "threading"]) 
+@pytest.mark.parametrize("context", ["multiprocessing"]) +# @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) async def test_gather(node_handler_setup, gen_node, context): node_handler, _ = node_handler_setup @@ -315,3 +320,26 @@ async def test_gather(node_handler_setup, gen_node, context): assert len(results) > 0 assert await node_handler.async_destroy_node(gen_node.id) + + +@pytest.mark.parametrize("context", ["multiprocessing"]) +# @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) +# @pytest.mark.parametrize("context", ["threading"]) +async def test_diagnostics(node_handler_setup, gen_node, context): + + config.set("diagnostics.interval", 2) + config.set("diagnostics.logging-enabled", True) + + node_handler, _ = node_handler_setup + + assert await node_handler.async_create_node( + cpe.NodeConfig(gen_node, context=context) + ) + assert await node_handler.async_start_nodes() + assert await node_handler.async_diagnostics(True) + await asyncio.sleep(1) + assert await node_handler.async_stop_nodes() + + assert await node_handler.async_destroy_node(gen_node.id) + + assert (node_handler.state.tempfolder / gen_node.name / "diagnostics.csv").exists() diff --git a/test/worker/test_http_client.py b/test/worker/test_http_client.py index 7897727b..60362afa 100644 --- a/test/worker/test_http_client.py +++ b/test/worker/test_http_client.py @@ -1,11 +1,13 @@ import asyncio +import logging import os import shutil +from typing import Any import pytest +from aiodistbus import make_evented from chimerapy.engine import _logger -from chimerapy.engine.eventbus import Event, EventBus, make_evented from chimerapy.engine.networking.server import Server from chimerapy.engine.states import NodeState, WorkerState from chimerapy.engine.worker.http_client_service import HttpClientService @@ -26,14 +28,17 @@ async def server(): await server.async_shutdown() -@pytest.fixture -async def http_client(): +async def handler(*args, **kwargs): + 
logger.debug("Received data") + logger.debug(f"{args}, {kwargs}") + - # Event Loop - eventbus = EventBus() +@pytest.fixture +async def http_client(bus, entrypoint): # Requirements - state = make_evented(WorkerState(), event_bus=eventbus) + state = WorkerState() + state = make_evented(state, bus=bus) logger = _logger.getLogger("chimerapy-engine-worker") log_receiver = _logger.get_node_id_zmq_listener() log_receiver.start(register_exit_handlers=True) @@ -42,38 +47,32 @@ async def http_client(): http_client = HttpClientService( name="http_client", state=state, - eventbus=eventbus, logger=logger, logreceiver=log_receiver, ) - await http_client.async_init() + await http_client.attach(bus) yield http_client - - await eventbus.asend(Event("shutdown")) + await entrypoint.emit("shutdown") async def test_http_client_instanciate(http_client): ... -@pytest.mark.skip(reason="Manager not working") async def test_connect_via_ip(http_client, manager): assert await http_client._async_connect_via_ip(host=manager.host, port=manager.port) -@pytest.mark.skip(reason="Manager not working") async def test_connect_via_zeroconf(http_client, manager): await manager.async_zeroconf() assert await http_client._async_connect_via_zeroconf() -@pytest.mark.skip(reason="Manager not working") async def test_node_status_update(http_client, manager): assert await http_client._async_connect_via_ip(host=manager.host, port=manager.port) assert await http_client._async_node_status_update() -@pytest.mark.skip(reason="Manager not working") async def test_worker_state_changed_updates(http_client, manager): assert await http_client._async_connect_via_ip(host=manager.host, port=manager.port) @@ -86,6 +85,7 @@ async def test_worker_state_changed_updates(http_client, manager): # Check assert "test" in manager.state.workers[http_client.state.id].nodes + assert "test" in manager.state.workers[http_client.state.id].nodes async def test_send_archive_locally(http_client): diff --git a/test/worker/test_http_server.py 
b/test/worker/test_http_server.py index 850fb662..a63cc1cc 100644 --- a/test/worker/test_http_server.py +++ b/test/worker/test_http_server.py @@ -6,8 +6,12 @@ from pytest_lazyfixture import lazy_fixture from chimerapy.engine import _logger -from chimerapy.engine.data_protocols import NodeDiagnostics, NodePubTable -from chimerapy.engine.eventbus import EventBus +from chimerapy.engine.data_protocols import ( + GatherData, + NodeDiagnostics, + NodePubTable, + ResultsData, +) from chimerapy.engine.networking.client import Client from chimerapy.engine.networking.data_chunk import DataChunk from chimerapy.engine.networking.enums import NODE_MESSAGE @@ -26,20 +30,17 @@ def pickled_gen_node_config(gen_node): @pytest.fixture -async def http_server(): - - # Event Loop - eventbus = EventBus() +async def http_server(bus): # Requirements state = WorkerState() # Create the services - http_server = HttpServerService( - name="http_server", state=state, eventbus=eventbus, logger=logger - ) + http_server = HttpServerService(name="http_server", state=state, logger=logger) + await http_server.attach(bus) await http_server.start() - return http_server + yield http_server + await http_server.shutdown() @pytest.fixture @@ -103,18 +104,15 @@ async def test_http_server_routes(http_server, route_type, route, payload): (NODE_MESSAGE.STATUS, NodeState(logdir=None).to_dict()), ( NODE_MESSAGE.REPORT_GATHER, - { - "node_id": "test", - "latest_value": DataChunk().to_json(), - }, + GatherData(node_id="test", output=DataChunk().to_json()).to_dict(), ), ( NODE_MESSAGE.REPORT_RESULTS, - {"success": True, "output": 1, "node_id": "test"}, + ResultsData(node_id="test", success=True, output=None).to_dict(), ), ( NODE_MESSAGE.DIAGNOSTICS, - {"node_id": "test", "diagnostics": NodeDiagnostics().to_dict()}, + NodeDiagnostics().to_dict(), ), ], ) diff --git a/test/worker/test_worker.py b/test/worker/test_worker.py index 2d72ee1f..140fd9f5 100644 --- a/test/worker/test_worker.py +++ 
b/test/worker/test_worker.py @@ -1,7 +1,7 @@ import chimerapy.engine as cpe -from ..networking.test_client_server import server -from ..streams.data_nodes import AudioNode, ImageNode, TabularNode, VideoNode +from ..core.networking.test_client_server import server +from ..node.streams.data_nodes import AudioNode, ImageNode, TabularNode, VideoNode logger = cpe._logger.getLogger("chimerapy-engine") cpe.debug() @@ -29,3 +29,9 @@ async def test_worker_instance_async(): worker = cpe.Worker(name="local", id="local", port=0) await worker.aserve() await worker.async_shutdown() + + +def test_worker_instance_sync(): + worker = cpe.Worker(name="local", id="local", port=0) + worker.serve() + worker.shutdown()