diff --git a/.cruft.json b/.cruft.json
index 481ce59..9975ef9 100644
--- a/.cruft.json
+++ b/.cruft.json
@@ -1,6 +1,6 @@
 {
   "template": "git@github.com:UrbanMachine/create-ros-app.git",
-  "commit": "92da06d68aabc5ba36e559b7fdd83cf7a54a997f",
+  "commit": "5cbe2af623cd536fbb52420e3dc8e3205ff6c86e",
   "checkout": null,
   "context": {
     "cookiecutter": {
@@ -9,13 +9,14 @@
       "dockerhub_username_or_org": "urbanmachine",
       "project_name": "node_helpers",
       "project_description": "An opinionated ROS2 framework that minimizes boilerplate while maximizing reliability. Features intuitive APIs for parameter management, action handling, and error-resilient RPC. Designed by Urban Machine for safe and scalable robotics.",
-      "version": "1.0.0",
+      "version": "0.5.0",
       "license": "MIT",
       "example_package_name": "node_helpers",
       "example_package_description": "An opinionated ROS2 framework that minimizes boilerplate while maximizing reliability. Features intuitive APIs for parameter management, action handling, and error-resilient RPC. Designed by Urban Machine for safe and scalable robotics.",
       "example_node_name": "ExampleNode",
       "example_launch_profile": "node_helpers_showcase",
-      "example_package_version": "0.1.0",
+      "example_package_version": "0.5.0",
+      "__example_messages_package_name": "node_helpers_msgs",
       "_template": "git@github.com:UrbanMachine/create-ros-app.git"
     }
   },
diff --git a/.github/lint/main.py b/.github/lint/main.py
index 35d1d6a..2882acf 100644
--- a/.github/lint/main.py
+++ b/.github/lint/main.py
@@ -21,8 +21,8 @@ LINTERS = {
     PYTHON_LANGUAGE: [
         # Run linters from fastest to slowest
-        lint_ruff_check,
         lint_ruff_format,
+        lint_ruff_check,
         lint_darglint,
         lint_mypy,
     ],
diff --git a/.github/lint/paths.py b/.github/lint/paths.py
index 8eda43c..6a02790 100644
--- a/.github/lint/paths.py
+++ b/.github/lint/paths.py
@@ -11,8 +11,8 @@ def required_path(path_str: str) -> Path:
 
 ROS_PATH = required_path("pkgs")
 LINT_PATH = required_path(".github/lint")
-JS_PATH = Path(".")
-FIRMWARE_PATH = Path(".")
+JS_PATH = Path()
+FIRMWARE_PATH = Path()
 LAUNCH_PATH = required_path("launch-profiles")
diff --git a/README.md b/README.md
index 4e23a79..5da58ac 100644
--- a/README.md
+++ b/README.md
@@ -15,21 +15,32 @@ An opinionated ROS2 framework that minimizes boilerplate while maximizing reliab
 ## Running This Project
 
 For in-depth documentation on the repository features, read the [About Template](docs/about_template.md) documentation.
+This project is a collection of ROS utilities that play nicely together. It's
+recommended to start by reading the highlights in ``docs/``. Smaller utilities are
+documented in READMEs in their respective modules.
 
-### Dependencies
+For example, ``node_helpers.timing`` has a README describing the APIs in that module.
+However, ``node_helpers.parameters`` has a page under ``docs/`` that describes the
+philosophy and usage of the module in depth.
 
-This project depends on [Docker](https://docs.docker.com/get-docker/), and can be accelerated using [Nvidia Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). Install both before proceeding.
+### Dependencies
 
-The linting tooling requires [Poetry](https://python-poetry.org/docs/) to run.
+This project depends on [Docker](https://docs.docker.com/get-docker/). The linting tooling requires [Poetry](https://python-poetry.org/docs/) to run.
 ### Running the Project
 
+This project is intended to be used as a library; however, there is a showcase
+launch-profile that demonstrates some of the library's features.
+
 To run the project, use the following command:
 
 ```shell
 docker/launch node_helpers_showcase
 ```
 
+Take a look at the nodes under `pkgs/node_helpers/nodes`, and the launch file under
+`launch-profiles`, to get an idea of some of the library's features.
+
 Then, open http://localhost/ on your browser to view the project logs.
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 75a68f4..6e4947e 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -26,6 +26,8 @@ services:
       - "./launch-profiles/${LAUNCH_PROFILE:- 'set in docker/_shared'}:/robot/launch-profile/"
       # Necessary for display passthrough
       - "/tmp/.X11-unix:/tmp/.X11-unix:rw"
+      # Necessary for PulseAudio passthrough
+      - "/run/user/${USER_ID:-1000}/pulse/native:/pulse-socket"
       # Build cache, used by `save-build-cache` and `restore-build-cache` docker scripts
       - type: volume
         source: ros-nodes-build-cache
@@ -33,6 +35,8 @@ services:
     environment:
       # Necessary for display passthrough
       DISPLAY: $DISPLAY
+      # Necessary for PulseAudio passthrough
+      PULSE_SERVER: "unix:/pulse-socket"
     # Gives the container access to kernel capabilities, useful for most robots
     network_mode: host
     cap_add:
diff --git a/docker/Dockerfile b/docker/Dockerfile
index da359da..b37c570 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -108,9 +108,6 @@ RUN echo "/opt/ros/${ROS2_DISTRO}/lib/${PYTHON_VERSION}/site-packages" >> /usr/l
 RUN echo "/opt/ros/${ROS2_DISTRO}/local/lib/${PYTHON_VERSION}/dist-packages" >> /usr/local/lib/${PYTHON_VERSION}/dist-packages/ros2.pth
 RUN make-pth-file-from-workspace "$(pwd)/install" /usr/local/lib/${PYTHON_VERSION}/dist-packages/robot.pth
 
-# Copy in useful runtime scripts
-COPY docker/utils/runtime/* /usr/local/bin/
-
 # Move the build cache from a Docker cache mount to a place where our build
 # system can see it. This helps make `colcon build` faster between runs.
 RUN --mount=type=cache,target="${BUILD_CACHE}" restore-build-cache
diff --git a/docker/utils/environment/make-pth-file-from-workspace b/docker/utils/environment/make-pth-file-from-workspace
index 60bdbe6..d539fcc 100755
--- a/docker/utils/environment/make-pth-file-from-workspace
+++ b/docker/utils/environment/make-pth-file-from-workspace
@@ -8,7 +8,7 @@
 #   make-pth-file-from-workspace {workspace install dir} {.pth file destination}
 #
 # Example:
-#   make-pth-file-from-workspace /robot/install /usr/local/lib/python3.8/dist-packages/robot.pth
+#   make-pth-file-from-workspace /robot/install /usr/local/lib/pythonX.XX/dist-packages/robot.pth
 
 set -o errexit
 set -o pipefail
diff --git a/docs/about_template.md b/docs/about_template.md
index 304d5ad..809df75 100644
--- a/docs/about_template.md
+++ b/docs/about_template.md
@@ -1,7 +1,7 @@
 # Using `create-ros-app`
 
 This repository was initialized by the [create-ros-app](https://github.com/UrbanMachine/create-ros-app) template.
-This template is a everything-but-the-robot-code starter for ROS projects. It includes a Dockerfile for building ROS packages, a GitHub Actions workflow for linting and autoformatting, and a few other goodies.
+This template is an everything-but-the-robot-code starter for ROS projects. It includes a Dockerfile for building ROS packages, a GitHub Actions workflow for linting and autoformatting, and many other goodies.
 
 This documentation walks through the features of the template, and how to use them.
@@ -56,15 +56,15 @@ Here's a quick guide on the features of this template
 
 The packages directory contains all the packages that are used in the project. Each package is added in the `Dockerfile`, and any new packages should be added there as well.
 
-#### Package structure
-Each package is made up of:
+#### Python Package structure
+Each Python package is made up of:
 - A `resource` directory, which is a colcon requirement
 - A `package.xml` file, which is a colcon requirement
 - A `pyproject.toml`, because this project uses [colcon-poetry-ros](https://github.com/UrbanMachine/colcon-poetry-ros) to install project dependencies. Most ROS python packages use `setup.py`, but by using this plugin, we can use a modern python tool called [Poetry](https://python-poetry.org/) to manage dependencies.
 - A directory for code
 - A directory for tests
 
-#### Test directories
+##### Test directories
 
 As (arbitrary) best practice, the example node uses a test directory that follows the following structure
 
@@ -84,6 +84,12 @@ package_name/
 
 Essentially, tests exist in a parallel directory to the package, and are split into `unit` and `integration` tests. The directories within `unit` and `integration` mirror the structure of the package itself, except that module names are prefixed with `test_`.
 
+#### Message Package Structure
+
+The template will generate a message package for you with an `ExampleAction`, `ExampleService`, and `ExampleMessage`. You can add more messages by adding them to the `msg` directory and updating the `CMakeLists.txt` and `package.xml` files.
+
+This can be used as a place to store messages that are used just for this project. It follows the standard ROS2 message package structure.
+
 ### `.github/`
 
 This project uses poetry for linting, and has some code for running linting and autoformatting under `.github/lint`.
diff --git a/docs/actions.rst b/docs/actions.rst
new file mode 100644
index 0000000..c0dbd53
--- /dev/null
+++ b/docs/actions.rst
@@ -0,0 +1,284 @@
+Actions
+===============
+
+The ``node_helpers.actions`` framework is designed to provide an efficient and
+easy-to-use structure for implementing and running actions in a ROS2 node. The purpose
+is to facilitate DRY code and the development of safe and robust action servers.
+
+Implementing Actions
+--------------------
+
+The two key APIs exposed by ``node_helpers.actions`` that assist with the
+implementation of actions are the ``ActionWorker`` and ``ActionHandler``.
+All actions in the Urban Machine codebase are written using these APIs.
+
+Let's start by reviewing two actions: one written the way an ordinary ROS action is
+written, and another written using the ``node_helpers.actions`` framework.
+Then, the differences between the two will be explained.
+
+Comparing Ordinary Actions to Node Helpers
+******************************************
+
+Example Ordinary ROS Action
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+    class CoolThingNode(Node):
+        def __init__(self):
+            super().__init__('cool_node')
+            self._action_server = ActionServer(
+                self,
+                CoolThing,
+                'cool_thing',
+                self.action_callback
+            )
+
+        def action_callback(self, goal_handle: ServerGoalHandle) -> CoolThing.Result:
+            try:
+                self.get_logger().info('Executing goal...')
+
+                while goal_not_reached:
+                    # Publish some feedback
+                    feedback_msg = CoolThing.Feedback()
+                    goal_handle.publish_feedback(feedback_msg)
+
+                    # Check for cancellation
+                    if goal_handle.is_cancel_requested:
+                        put_robot_in_safe_state()
+                        goal_handle.canceled()
+                        result = CoolThing.Result()
+                        return result
+
+                    # Make the robot do something
+                    robot_do_thing()
+
+                goal_handle.succeed()
+
+                result = CoolThing.Result()
+                return result
+            except Exception as e:
+                put_robot_in_safe_state()
+                goal_handle.abort()
+                raise e
+
+Example node_helpers Action
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+    class CoolThingWorker(
+        ActionWorker[CoolThing.Goal, CoolThing.Feedback, CoolThing.Result]
+    ):
+        def __init__(
+            self,
+            handler: "CoolThingActionHandler",
+            goal_handle: ServerGoalHandle,
+            logger: Logger,
+        ) -> None:
+            super().__init__(goal_handle, logger)
+            self.handler = handler
+
+        def run(self) -> Generator[CoolThing.Feedback | None, None, None]:
+            """Action callback code goes here"""
+            self.get_logger().info('Executing goal...')
+
+            while goal_not_reached:
+                # Yield some feedback (optional)
+                yield CoolThing.Feedback()
+
+                # Yield `None` during times when it's okay for the action to cancel.
+                # The ActionWorker will automatically check for and handle cancellation.
+                yield None
+
+                # Make the robot do something
+                robot_do_thing()
+
+            # Set the result
+            self.result = CoolThing.Result(was_cool=True)
+
+        def on_exception(self, ex: Exception) -> None:
+            """What to do if an unexpected exception happens."""
+            put_robot_in_safe_state()
+
+        def on_cancel(self) -> CoolThing.Result:
+            """Cancellation code goes here"""
+            put_robot_in_safe_state()
+
+            return CoolThing.Result()
+
+
+    class CoolThingActionHandler(
+        FailFastActionHandler[CoolThing.Goal, CoolThing.Feedback, CoolThing.Result]
+    ):
+        class Parameters(BaseModel):
+            some_parameter: float
+
+        def __init__(self, node: HelpfulNode, action_name: str):
+            super().__init__(node=node, action_name=action_name, action_type=CoolThing)
+            self._params = node.declare_from_pydantic_model(self.Parameters, action_name)
+
+        def create_worker(self, goal_handle: ServerGoalHandle) -> CoolThingWorker:
+            return CoolThingWorker(
+                handler=self, goal_handle=goal_handle, logger=self.node.get_logger()
+            )
+
+Anatomy of the Example
+**********************
+
+``CoolThingWorker`` is a subclass of ``ActionWorker`` and is responsible for executing the
+action logic through its ``run`` method. The ``run`` method returns a generator that can yield
+feedback messages (to be published) or ``None`` (to facilitate cancellation handling).
+
+In case of cancellation, the ``on_cancel`` method is called to handle it and safely transition
+the robot out of the current action. If the ``run`` method raises an exception, the
+``on_exception`` method is invoked.
+
+``CoolThingActionHandler`` is a subclass of ``FailFastActionHandler``. This means that it will
+reject new action requests if the action is already in progress. The handler is responsible
+for managing the action's parameters through a Pydantic ``Parameters`` class and creating
+``CoolThingWorker`` instances to handle incoming action requests.
+
+The ``create_worker`` method in ``CoolThingActionHandler`` constructs a new ``CoolThingWorker``
+instance, providing the required parameters (handler, goal_handle, logger) for the execution
+of the action.
+
+Key Differences
+***************
+
+Here are the key differences between the two examples:
+
+1. Clearer separation of concerns: node_helpers actions split the action logic out of
+   nodes, and into separate files, encouraging development of reusable actions for multiple nodes.
+2. Metrics: each action responded to by an ``ActionHandler`` will publish metrics to a
+   ``/metrics/write`` topic, which we use for visualizing actions in our ``Grafana`` dashboard.
+3. Easier cancellation handling: node_helpers actions automatically handle cancellation
+   requests by using a generator, which simplifies the cancellation logic.
+4. Easier error handling: node_helpers actions will always run ``on_exception`` in case
+   of an exception, reducing boilerplate.
+5. Easily used alongside the ``node_helpers.parameters`` framework, allowing you to
+   request per-action configuration in the ``ActionHandler``, and then use it in the
+   ``ActionWorker``.
+6. Takes care of calling the ``succeed``, ``abort``, and ``canceled`` methods of ``goal_handle``.
+
+
+Other Action Handlers
+---------------------
+
+- ``ActionHandler``: the default base handler. It takes care of creating the ``ActionServer``,
+  and publishing metrics about the actions.
+- ``FailFastActionHandler`` is useful for actions that the user knows should not run concurrently.
+  For example, if you have an action that moves an axis on a robot, it might be bad if two callers
+  were calling it at the same time. The ``FailFastActionHandler`` will quickly raise an exception,
+  and give the writer of the action peace of mind that it can't be used incorrectly.
+- ``QueuedActionHandler`` is useful for actions that aren't safe to run in parallel, but are
+  safe to run sequentially in any order. For example, an action that takes a picture with a
+  camera. This handler will queue up requests and respond to them in order.
+- ``ContextActionHandler`` is useful for defining actions that behave like Python context managers.
+  This won't be covered in this tutorial.
+
+Running Actions
+---------------
+
+Running a Single Action
+***********************
+
+When running a single action, it's almost always preferred to use the default ROS methodology.
+When running an action synchronously, we almost always use ``send_goal``:
+
+.. code-block:: python
+
+    action_client.send_goal(CoolAction.Goal())
+
+When running an action asynchronously, the pattern is almost always to run ``send_goal_async``,
+wait for the goal to be accepted, then continue. Here's an example of **bad practice**:
+
+.. code-block:: python
+
+    send_goal_future = action_client.send_goal_async(CoolAction.Goal())
+
+    # Wait for the goal to be accepted by the server
+    while not send_goal_future.done():
+        pass
+    goal_handle = send_goal_future.result()
+
+    # Get the future for the goal result
+    result_future = goal_handle.get_result_async()
+
+That's so much boilerplate! There is a shorthand for this type of sequence under
+``node_helpers.futures``:
+
+.. code-block:: python
+
+    from node_helpers import futures
+
+    result_future, client_handle = futures.wait_for_send_goal(
+        action_client, CoolAction.Goal()
+    )
+
+In this example, ``wait_for_send_goal`` waits for the goal to be accepted, then
+returns the result future and the client handle.
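+
+For completeness, here is a sketch of how the returned result future can be consumed
+later on. This is an illustrative assumption rather than part of the original docs:
+the ``futures.wait_for_futures`` signature and the ``.result`` unwrapping are borrowed
+from their usage inside ``ActionGroup``, and ``CoolAction`` remains a placeholder type:
+
+.. code-block:: python
+
+    from node_helpers import futures
+
+    result_future, client_handle = futures.wait_for_send_goal(
+        action_client, CoolAction.Goal()
+    )
+
+    # ... do other work while the action runs ...
+
+    # Block until the result arrives. Futures from get_result_async() resolve to
+    # an {Action}_GetResult_Response, whose `.result` field holds the actual
+    # Result message (as seen in ActionGroup.wait_for_results)
+    (response,) = futures.wait_for_futures([result_future], object, timeout=30)
+    result = response.result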
+
+Running Many Actions
+********************
+
+This robot has a lot of moving parts, and a lot of actions for those moving parts. A common
+pattern is to send multiple actions in parallel to different parts of the robot, requesting
+that those parts start moving at the same time.
+
+There is a mini framework for making that sort of code easier to write. It consists of
+the ``ActionElement``, ``ActionGroup``, and ``ActionSequence`` APIs. The most important is the
+``ActionGroup``.
+
+ActionGroup
+~~~~~~~~~~~
+
+An ``ActionGroup`` is a utility class for running multiple actions simultaneously and
+waiting for their completion or a specified partial completion. It can be thought of
+as a framework for combining multiple actions into a single object which you can use
+as though it were one action.
+
+It takes in a list of ``ActionElements``, which are just ``dataclasses`` that hold the
+action client and the goal to run on it.
+
+Creating an ActionGroup
+^^^^^^^^^^^^^^^^^^^^^^^
+
+Below is an example of an action group being created. No action is run in this example.
+
+.. code-block:: python
+
+    # Create an ActionGroup with ActionElements
+    action_group = ActionGroup(
+        ActionElement(client=action_client_1, goal=goal_1),
+        ActionElement(client=action_client_2, goal=goal_1, feedback_callback=on_feedback),
+    )
+
+Running Synchronously
+^^^^^^^^^^^^^^^^^^^^^
+
+ActionGroups mirror the ROS ActionClient API, with some added features. Here's an example
+of actually running the action group, and synchronously waiting for the results:
+
+.. code-block:: python
+
+    results = action_group.send_goals()
+
+Running Asynchronously
+^^^^^^^^^^^^^^^^^^^^^^
+
+It's also sometimes desirable to run multiple actions and not block. Here's how you
+run all the actions asynchronously, and then choose when you block for results:
+
+.. code-block:: python
+
+    # The following line starts all of the actions in the group in parallel
+    action_group.send_goals_async()
+
+    # You can blockingly wait for results by using:
+    results = action_group.wait_for_results()
+
+    # Alternatively, you can yield periodically until results arrive. This gives
+    # ActionWorkers a chance to check for cancellation, if running this action group
+    # within the context of a larger action
+    results = yield from action_group.yield_for_results()
+
+Cancelling Action Groups
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+It's also possible to request cancellation of an entire action group at once:
+
+.. code-block:: python
+
+    action_group.cancel_goals()
\ No newline at end of file
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 0000000..0273d9c
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,21 @@
+Node Helpers
+===============
+
+The ``node_helpers`` package consists of various frameworks
+and utilities that help make writing ROS code more seamless. There is no
+code specific to the Urban Machine robot here; rather, it's intended to be generically
+useful for quickly building robust, reliable, and maintainable ROS nodes.
+
+This documentation covers the important modules within ``node_helpers``, with some
+explanation about their intended use within the larger framework.
+
+.. toctree::
+   :caption: Contents:
+
+   actions
+   launching
+   parameters
+   robust_rpc
+   sensors
+   about_template
diff --git a/docs/launching.rst b/docs/launching.rst
new file mode 100644
index 0000000..d6a7228
--- /dev/null
+++ b/docs/launching.rst
@@ -0,0 +1,55 @@
+Launching
+=========
+
+The ``node_helpers.launching`` module provides utility functions and classes to streamline the management of ROS launch files.
+
+Core Features
+-------------
+
+1. **Node Swapping**:
+
+   - The ``SwappableNode`` class allows nodes to track their name and namespace explicitly and facilitates dynamic swapping between "real" and "mock" nodes using the ``SwapConfiguration`` model.
+   - The ``apply_node_swaps`` function ensures a one-to-one mapping between nodes and their mocks, enforcing validation and consistency in swap configurations.
+
+   **Example Usage**:
+
+   .. code-block:: python
+
+      from node_helpers import launching
+
+      class MetaParameters:
+          swaps: dict[str, launching.SwapConfiguration]
+
+      param_loader: ParameterLoader[MetaParameters] = ParameterLoader(
+          parameters_directory=Path("/robot/launch-profile/"),
+          override_file=Path("/robot/launch-profile/config.override.yaml"),
+          meta_parameters_schema=MetaParameters,
+      )
+
+      launch_description = [
+          launching.SwappableNode(namespace="example_namespace", name="real_node"),
+          launching.SwappableNode(namespace="example_namespace", name="mock_node"),
+      ]
+
+      filtered_launch = launching.apply_node_swaps(
+          param_loader.meta_parameters.swaps, launch_description
+      )
+
+   This loads parameters from YAML files in the launch-profile directory, applies the swap configuration to the launch description, and returns a filtered launch description with the correct nodes.
+
+2. **File Validation**:
+
+   - Utility functions such as ``required_file`` and ``required_directory`` verify the existence of critical files and directories before launching. These checks prevent runtime errors caused by missing resources.
+
+   **Example Usage**:
+
+   .. code-block:: python
+
+      from node_helpers import launching
+
+      config_file = launching.required_file("/path/to/config.yaml")
+
+Error Handling and Validation
+-----------------------------
+
+The module includes robust error handling:
+
+- ``InvalidSwapConfiguration`` is raised when node swapping configurations are inconsistent or incomplete.
+- Validation ensures all specified nodes are present in the launch description and that pairs of swappable nodes are correctly configured.
+
diff --git a/docs/parameters.rst b/docs/parameters.rst
new file mode 100644
index 0000000..799c6af
--- /dev/null
+++ b/docs/parameters.rst
@@ -0,0 +1,251 @@
+Parameters
+===============
+
+The parameters framework in our robot is designed to make managing and updating parameters in your ROS nodes efficient
+and easy. This documentation will guide you through the process of defining parameters in your ROS nodes and managing
+them using the layered parameter system.
+
+Why Not Use ROS Parameters Directly
+-----------------------------------
+
+The Parameters framework uses ROS parameters behind the scenes. Why are we using a framework in the first place?
+
+1) **Less boilerplate**: Simplifies the process of declaring and getting parameters, reducing the
+   amount of boilerplate code you need to write when working with parameters.
+
+2) **Pydantic model integration**: Allows you to define parameters using Pydantic models, which provides
+   automatic validation and type checking for your parameters. This helps ensure that the parameters you define are
+   correct and adhere to the specified constraints.
+
+3) **Custom type parsing support**: The Parameters library has various custom types that you can use
+   to automatically convert from the type in the YAML file to a specific Python type. Want to parse that
+   string as a Path? Or as a specific Enum? No problem!
+
+4) **Required parameters**: Enables you to mark certain parameters as required, meaning that they must be
+   set from an external source (such as a parameter file). If a required parameter is not set, an
+   exception is raised, allowing you to detect and handle missing parameter values more
+   effectively.
+
+5) **Subscribe to parameter updates**: Offers a way to subscribe attributes to parameter updates. When a
+   parameter is updated, the corresponding attribute is automatically updated as well, making it easier to keep your
+   robot's state in sync with the parameter values.
+
+
+Usage Examples
+--------------
+
+Basic Declaration of Pydantic Parameters
+****************************************
+
+Here's an example of using the ``Parameters`` framework within a ROS node.
+
+First off, you'll want to be working with a node that subclasses the ``ParameterMixin``. Typically in a codebase, the
+``node_helpers.nodes.HelpfulNode`` will be used, because it provides all of the ``node_helper`` mixins at once, including
+the ``ParameterMixin``.
+
+.. code-block:: python
+
+    class CoolNode(HelpfulNode):
+        class Parameters(BaseModel):
+            wood_tf: str
+
+            some_list_of_numbers: list[int]
+
+            calibration_thing: float
+
+            my_path: Path
+
+            some_enum: MyCustomEnum
+
+            maybe_int_maybe_str: int | str
+
+            value_or_default_none: MyCustomEnum | None = None
+
+        def __init__(self, **kwargs: Any) -> None:
+            super().__init__("cool_node", **kwargs)
+            self.params = self.declare_from_pydantic_model(self.Parameters, "node")
+
+In the above example, when launched, the ``self.params`` variable will hold all of the parameters loaded from the
+file: type checked, validated, and required to be filled in. If the file doesn't have those parameters specified, an
+exception will be raised, and the runner of the stack will get a clear error message.
+
+Furthermore, these parameters can be changed live, during runtime, using tools like ``RQT``, because the ``ParameterMixin``
+automatically subscribes Pydantic model attributes to changes in their respective parameters. For example, the
+``params.calibration_thing`` attribute could be edited during runtime.
+
+The Override-This Value
+^^^^^^^^^^^^^^^^^^^^^^^
+
+There's special support for a sentinel value ````. If you set a parameter
+to this value, it will raise an exception if the parameter is not set in the configuration file.
+
+
+The 'Choosable' System
+**********************
+
+The ``Parameters`` framework also supports a 'choosable class' system, which allows you
+to specify a parameter that will be used to select between different types of classes (or
+instances of classes).
+
+A common pattern in our codebase is to have different implementations of some base class
+which are used in different situations. For example, we might have a ``BaseCamera`` class
+that has different implementations for different types of cameras. We might have a
+``RealCamera`` class and a ``SimulatedCamera`` class, and we want to be able to choose
+between them using a parameter.
+
+To do this, we can have the ``BaseCamera`` class inherit from the
+``Choosable`` type, which allows you to specify any camera in Pydantic
+models, and have the type automatically found and returned.
+
+Basic Use
+^^^^^^^^^
+
+.. code-block:: python
+
+    # Declaration of registered classes
+    class BaseCamera(Choosable):
+        # Normal base class stuff here
+        ...
+
+    class RealCamera(BaseCamera):
+        # Camera implementation here
+        ...
+
+    class SimulatedCamera(BaseCamera):
+        # Camera implementation here
+        ...
+
+    # Using a choosable class in a pydantic model
+    class Parameters(BaseModel):
+        camera_type: type[BaseCamera]  # <-- notice how this is wrapped with type[]
+
+    # In a node somewhere
+    self.params = self.declare_from_pydantic_model(self.Parameters, "camera_params")
+
+Then in configuration, to set the value of camera_type, you can use the following:
+
+.. code-block:: yaml
+
+    camera_params:
+      camera_type: "RealCamera"
+
+This type can then be accessed via ``self.params.camera_type``.
+
+Custom Names
+^^^^^^^^^^^^
+
+It can be desirable to have a different name for the parameter than the name of the
+class. For example, we might want to have a camera be referred to as 'real_camera'
+but the class name is 'RealCamera'. To do this, we could change the above example and
+use metaclass parameters to specify the ``registered_name`` argument:
+
+.. code-block:: python
+
+    class RealCamera(BaseCamera, registered_name="real_camera"):
+        pass
+
+Now in configuration this can be accessed like so:
+
+.. code-block:: yaml
+
+    camera_params:
+      camera_type: "real_camera"
+
+Choosable Instances
+^^^^^^^^^^^^^^^^^^^
+
+Just like you can have choosable classes, you can also have choosable instances. This
+allows you to 'register' instantiated objects and have them be accessible via a
+parameter.
+
+This is most often used in the ``urdf_data.constants`` package, where we have many different
+``URDFConstants`` instances that are referenced by name in the configuration.
+
+Here's an example of a Choosable Instance:
+
+.. code-block:: python
+
+    class MyChoosableInstance(Choosable):
+        ...
+
+    my_instance_a = MyChoosableInstance(cool_param=3)
+    my_instance_a.register_instance("instance_1")
+
+    my_instance_b = MyChoosableInstance(cool_param=5)
+    my_instance_b.register_instance("instance_2")
+
+    class Parameters(BaseModel):
+        my_instance: MyChoosableInstance  # <-- notice how this isn't wrapped with type[]
+
+    # In a node somewhere
+    self.params = self.declare_from_pydantic_model(self.Parameters, "my_params")
+
+Then in configuration, to set the value of my_instance, you can use the following:
+
+.. code-block:: yaml
+
+    my_params:
+      my_instance: "instance_1"
+
+.. warning::
+
+   When registering an instance, this will hold a reference to the instance in memory!
+   It will never get garbage collected.
+
+   In the future, we may want to change this feature to instead use weakrefs.
+
+
+
+Parameter Loading and Layered Parameters
+----------------------------------------
+
+Our parameter system is designed to be layered, which means it can load multiple YAML files and combine them. This
+allows you to have different parameter configurations for different environments or situations, and easily switch
+between them.
+
+The ``ParameterLoader`` is responsible for loading the parameter files and merging them together. When writing a launch
+file, you can specify which parameter files should be loaded, and the ``ParameterLoader`` will take care of merging them
+in the order specified.
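+
+As a rough sketch, here is what that can look like in a launch file (the constructor
+arguments are the ones shown in the launching docs; the import path and file locations
+are illustrative assumptions):
+
+.. code-block:: python
+
+    from pathlib import Path
+
+    from node_helpers.parameters import ParameterLoader
+
+    # Layered loading: YAML files under parameters_directory are merged, with the
+    # override file layered on top (merge order as described above)
+    param_loader = ParameterLoader(
+        parameters_directory=Path("/robot/launch-profile/"),
+        override_file=Path("/robot/launch-profile/config.override.yaml"),
+    )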
+
+For information about how the parameters are configured for this robot, check out our `Configuration`_ docs.
+
+.. _Configuration: ../../deployment/ros_configuration.html
+
+Meta Parameters
+***************
+
+In addition to the standard ROS parameters, our ``ParameterLoader`` system also supports an optional ``MetaParameters``
+field to be specified. ``MetaParameters`` allow you to insert a custom Pydantic model into an otherwise fully
+ROS-parameter-compatible YAML file.
+
+Why do this? Well, it's tremendously helpful when writing launch files to have configuration
+for which nodes *the launch file will create* in the same place as the actual ROS parameter
+configuration for the nodes themselves!
+
+For example, you might have a configuration file that looks like this:
+
+.. code-block:: yaml
+
+    meta_parameters:
+      camera_namespaces: ["camera_1", "camera_2"]
+      urdfs: ["/path/to/urdf", "/path/to/another/urdf"]
+
+    camera_1:
+      camera_node:
+        some_param: 3
+
+    camera_2:
+      camera_node:
+        some_param: 5
+
+Then inside of the launch file you specify the MetaParameters using the following model:
+
+.. code-block:: python
+
+    class MyCoolMetaParameters(BaseModel):
+        camera_namespaces: list[str]
+        urdfs: list[str]
+
+Using this, you can use ``MyCoolMetaParameters`` when loading parameters using the ``ParameterLoader``, and extract
+information from the configuration in order to dynamically generate all of the nodes specified under ``camera_namespaces``.
+
diff --git a/docs/robust_rpc.rst b/docs/robust_rpc.rst
new file mode 100644
index 0000000..730c1c0
--- /dev/null
+++ b/docs/robust_rpc.rst
@@ -0,0 +1,207 @@
+Robust RPC
+===============
+
+The RobustRPC framework is one of the key components of ``node_helpers``. Its key API
+is the ``RobustRPCMixin``, which provides a robust approach to handling errors in service
+and action calls by propagating error messages raised by the server and re-raising them
+on the client side.
+This documentation aims to help users understand and effectively use the ``RobustRPCMixin``.
+
+
+What is it, and why use it?
+---------------------------
+
+The ``RobustRPCMixin`` is a mixin class in the node_helpers ROS package designed to simplify error handling for remote procedure calls using services and actions. By integrating this mixin into your custom ROS nodes, you can improve error handling and maintainability in your robotic system.
+
+Key benefits of using ``RobustRPCMixin``:
+
+- Automatically propagate error messages from the server to the client
+- Pythonic error handling: you can ``try``/``except`` remote errors
+- Reduce boilerplate code (no more checking if every action succeeded or failed)
+
+Main Classes and Features
+-------------------------
+
+1. ``RobustRPCMixin``: This mixin adds methods for creating service and action clients
+   and servers. It ensures that any errors raised by the server are caught, passed in
+   the response message, and re-raised on the client side. The client will raise an
+   error of the same name that subclasses ``RobustRPCException``.
+
+2. ``RobustActionClient``: A subclass of ``ActionClient`` that wraps it to provide
+   functionality for raising remote errors when calling ``send_goal_async().result()`` or
+   ``send_goal()``. It includes the ``send_goal_as_context`` method, which is a context
+   manager that sends a goal and cancels it after the context ends.
+
+3. ``RobustServiceClient``: A subclass of ``rclpy.Client`` that wraps it to provide
+   functionality for raising remote errors when calling ``call_async().result()`` or ``call()``.
+
+4. Both the action and service clients will **automatically wait until the server is online!**
+   This means that upon the first call, they will wait until the server is online before
+   actually making the call. This fixes the annoying default ROS2 behavior of allowing
+   you to make the call, then quietly hanging and blocking forever.
+
+
+What do I have to change to use it?
+-----------------------------------
+
+Only two things!
+
+1. First, when instantiating a service or action, use the ``RobustRPCMixin`` methods
+   to do so.
+
+2. Use messages that contain the following:
+
+   .. code-block:: text
+
+      string error_name
+      string error_description
+
+   That is to say, in an Action message, include those two fields in the ``result`` message.
+   In a service message, include those two fields in the ``response`` message.
+
+Server Usage Example
+--------------------
+
+In the example below, the ``YourNode`` class is a subclass of the ``Node`` and
+``RobustRPCMixin`` classes. Both service and action servers are created using the
+``create_robust_service`` and ``create_robust_action_server`` methods from the ``RobustRPCMixin``.
+
+The callback functions for the service and action server raise
+ordinary Python exceptions. When an exception is raised, it is magically captured by
+the RobustRPCMixin wrappers, which pack the error name and description into the message,
+and send that message back to the client.
+
+The robust client will then re-raise those exact exceptions.
+
+.. code-block:: python
+
+    from rclpy.node import Node
+    from rclpy.action.server import ServerGoalHandle
+
+    from node_helpers.robust_rpc import RobustRPCMixin, RobustRPCException
+    from your_package.action import YourRobustAction
+    from your_package.srv import YourRobustService
+
+
+    class YourNode(Node, RobustRPCMixin):
+        def __init__(self):
+            super().__init__("your_node")
+
+            # Create robust service server
+            self.service_server = self.create_robust_service(
+                srv_type=YourRobustService,
+                srv_name="your_service_name",
+                callback=self.service_callback
+            )
+
+            # Create robust action server
+            self.action_server = self.create_robust_action_server(
+                action_type=YourRobustAction,
+                action_name="your_action_name",
+                execute_callback=self.action_callback
+            )
+
+        def service_callback(self, request: YourRobustService.Request, response: YourRobustService.Response):
+            # This behaves like an ordinary service callback, except that it's okay to
+            # raise python exceptions within it.
+            if something_bad:
+                raise SomeSpecificError("Oh no, something bad happened in my service!")
+            return response
+
+        def action_callback(self, goal_handle: ServerGoalHandle):
+            # This behaves like an ordinary action callback, except that it's okay to
+            # raise python exceptions within it.
+            if something_bad:
+                raise SomeSpecificError("Oh no, something bad happened in my action!")
+
+            goal_handle.succeed()
+            return result
+
+Client Usage Example
+--------------------
+
+In the example below, the client side of the robust framework is demonstrated.
+By using the ``create_robust_client()`` and ``create_robust_action_client()`` methods,
+any exceptions raised on the server side will be re-raised on the client side.
+
+Furthermore, the client will automatically wait for the server to come online the
+first time it's called. On subsequent calls, this check will be skipped.
+
+The ``call_service()`` and ``call_action()`` methods show how to call a service and
+call an action using these clients, with proper error handling using the
+``RobustRPCException.like()`` method and the ``RobustRPCException`` class.
+
+.. code-block:: python
+
+    from rclpy.node import Node
+
+    from node_helpers.robust_rpc import RobustRPCMixin, RobustRPCException
+    from your_package.action import YourRobustAction
+    from your_package.srv import YourRobustService
+
+
+    class YourNode(Node, RobustRPCMixin):
+        def __init__(self):
+            super().__init__("your_node")
+
+            # Create robust service client
+            self.service_client = self.create_robust_client(
+                srv_type=YourRobustService, srv_name="your_service_name"
+            )
+
+            # Create robust action client
+            self.action_client = self.create_robust_action_client(
+                action_type=YourRobustAction, action_name="your_action_name"
+            )
+
+        def call_service(self, request: YourRobustService.Request):
+            try:
+                response = self.service_client.call(request)
+                # Process response
+            except RobustRPCException.like("SomeSpecificError"):
+                # Handle specific error
+                ...
+            except RobustRPCException as e:
+                # Handle all other robust RPC errors
+                ...
+
+        def call_action(self, goal: YourRobustAction.Goal):
+            # Here is an example of the context-manager API of the action client.
+            # Upon entering the context, the action will be executed. Upon exiting the
+            # context, the action will be cancelled.
+            with self.action_client.send_goal_as_context(goal) as goal_handle:
+                # Perform tasks while goal is being executed
+                ...
+                # The context manager will automatically cancel the goal if not finished
+
+            # Here is an example of calling an action that might raise exceptions
+            # remotely, and how to catch and handle those exceptions
+            try:
+                result = self.action_client.send_goal(goal)
+                # Process result
+            except RobustRPCException.like(SomeSpecificError):
+                # Handle this specific error type
+                ...
+            except RobustRPCException as e:
+                # Handle all other robust RPC errors that are not "SomeSpecificError"
+                ...
+
+
+Catching Errors on the Client Side
+**********************************
+
+When using a robust service or action client, errors raised on the server side will be
+caught, passed in the response message, and re-raised on the client side. Because the
+client side is raising exceptions from strings, all client-side errors will subclass
+``RobustRPCException``. To catch these errors by "name", you can use the
+``RobustRPCException.like`` method:
+
+.. code-block:: python
+
+    from node_helpers.robust_rpc import RobustRPCException
+
+    try:
+        response = robust_service_client.call(request)
+    except RobustRPCException.like(ValueError):
+        # Handle a "ValueError" that was raised on the server side
+        ...
+    except RobustRPCException:
+        # Handle any other errors raised by the server
+        ...
+
+It is possible to use the ``like()`` method with a string value, instead of a reference
+to an exception. This is typically bad practice, because using a reference aids
+refactoring and makes it easier to discover where that error might be raised. If the
+exception gets renamed or removed, our lint tooling will catch the error and alert
+the developer.
\ No newline at end of file
diff --git a/docs/sensors.rst b/docs/sensors.rst
new file mode 100644
index 0000000..ddf3f8a
--- /dev/null
+++ b/docs/sensors.rst
@@ -0,0 +1,79 @@
+Sensors
+=======
+
+The ``node_helpers.sensors`` module is designed to standardize the way sensor data is handled in ROS2. By providing reusable publishers, buffers, and visualization tools, it streamlines the process of creating new sensors and integrating them into your application.
+
+Overview
+--------
+
+This module focuses on:
+
+1. **Publishers**: A structured way to publish sensor messages, ensuring consistent QoS settings and optional visualization support.
+2. **Buffers**: Simple tools for holding and retrieving the latest sensor data.
+3. **Predefined Sensors**: Ready-to-use components for common sensor types like binary signals and rangefinders.
+
+Core Components
+---------------
+
+1. BaseSensorPublisher
+**********************
+
+The ``BaseSensorPublisher`` class provides a base for publishing sensor messages with the following features:
+
+- QoS settings optimized for sensor data (``qos_profile_sensor_data`` by default).
+- Support for RViz visualization via the ``to_rviz_msg`` method.
+- Built-in throttling options for both sensor and visualization publishing rates.
+
+Example:
+
+.. code-block:: python
+
+    from node_helpers.sensors.base_publisher import BaseSensorPublisher
+    from std_msgs.msg import Header
+
+    class CustomSensor(BaseSensorPublisher):
+        def __init__(self, node, parameters):
+            super().__init__(node, msg_type=CustomMsg, parameters=parameters)
+
+        def to_rviz_msg(self, msg):
+            # Implement visualization markers
+            pass
+
+2. BaseSensorBuffer
+*******************
+
+The ``BaseSensorBuffer`` class provides a simple interface for subscribing to sensor topics and accessing the latest readings. It includes:
+
+- Event-based subscriptions via ``on_value_change`` and ``on_receive``.
+- Filtering of out-of-order messages.
+
+Example:
+
+.. code-block:: python
+
+    from node_helpers.sensors.base_buffer import BaseSensorBuffer
+
+    buffer = BaseSensorBuffer(node, CustomMsg, "/sensor/topic")
+    latest_reading = buffer.get()
+
+3. Predefined Sensors
+*********************
+
+- **Binary Sensors**: Tools for creating sensors with binary outputs (e.g., on/off).
+- **Rangefinders**: Components for publishing rangefinder data with visualization support.
+
+Binary Sensor Example:
+
+.. code-block:: python
+
+    from node_helpers.sensors.binary_signal import BinarySensor
+
+    binary_sensor = BinarySensor(node, parameters)
+    binary_sensor.publish_value(True)
+
+Rangefinder Example:
+
+.. code-block:: python
+
+    from node_helpers.sensors.rangefinder import RangefinderPublisher
+
+    rangefinder = RangefinderPublisher(node, parameters, qos)
+    rangefinder.publish_range(1.5)  # Publish range in meters
+
diff --git a/launch-profiles/node_helpers_showcase/launcher.py b/launch-profiles/node_helpers_showcase/launcher.py
index 83ce186..4e80f01 100644
--- a/launch-profiles/node_helpers_showcase/launcher.py
+++ b/launch-profiles/node_helpers_showcase/launcher.py
@@ -7,10 +7,6 @@ def generate_launch_description() -> LaunchDescription:
     rviz_config = "/robot/launch-profile/rviz-config.rviz"
 
     launch_description = [
-        Node(
-            package="rviz2",
-            executable="rviz2",
-            arguments=["-d", [rviz_config]]
-        ),
+        Node(package="rviz2", executable="rviz2", arguments=["-d", [rviz_config]]),
     ]
     return LaunchDescription(launch_description)
diff --git a/pkgs/node_helpers/README.md b/pkgs/node_helpers/README.md
index 4d99b03..0ddce99 100644
--- a/pkgs/node_helpers/README.md
+++ b/pkgs/node_helpers/README.md
@@ -1,3 +1,5 @@
 # node_helpers
 
-An opinionated ROS2 framework that minimizes boilerplate while maximizing reliability. Features intuitive APIs for parameter management, action handling, and error-resilient RPC. Designed by Urban Machine for safe and scalable robotics.
\ No newline at end of file
+An opinionated ROS2 framework that minimizes boilerplate while maximizing reliability.
+Features intuitive APIs for parameter management, action handling, and error-resilient RPC.
+Designed by Urban Machine for safe and scalable robotics.
\ No newline at end of file
diff --git a/pkgs/node_helpers/node_helpers/actions/README.md b/pkgs/node_helpers/node_helpers/actions/README.md
new file mode 100644
index 0000000..c681a40
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/actions/README.md
@@ -0,0 +1,14 @@
+# node_helpers.actions
+
+This module provides a robust framework for creating and managing ROS 2 actions, centered around two key APIs: `ActionWorker` and `ActionHandler`. These classes form the foundation of the module, enabling developers to define custom action logic, handle feedback and cancellations, and integrate seamlessly with ROS's action server infrastructure.
+
+**`ActionWorker`**:
+An abstract base class for implementing action execution logic with built-in support for feedback, timeouts, and cancellations.
+
+**`ActionHandler`**:
+A configurable handler that manages an action server, delegates goals to `ActionWorker` instances, and provides hooks for metrics and custom behavior.
+
+This module also includes tools for sequencing actions, managing complex workflows, and enforcing safe execution patterns, such as fail-fast behavior or context-managed long-running actions.
+
+The framework is documented in [docs/](../../../../docs/actions.rst).
\ No newline at end of file
diff --git a/pkgs/node_helpers/node_helpers/actions/__init__.py b/pkgs/node_helpers/node_helpers/actions/__init__.py
new file mode 100644
index 0000000..2e0a97b
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/actions/__init__.py
@@ -0,0 +1,10 @@
+from .action_sequences import (
+    ActionElement,
+    ActionGroup,
+    ActionSequence,
+    AlreadyRunningActionsError,
+    NoRunningActionsError,
+    ParallelActionSequences,
+)
+from .context_manager import ActionContextManager
+from .generators import generator_sleep
diff --git a/pkgs/node_helpers/node_helpers/actions/action_sequences.py b/pkgs/node_helpers/node_helpers/actions/action_sequences.py
new file mode 100644
index 0000000..3193c2c
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/actions/action_sequences.py
@@ -0,0 +1,397 @@
+import itertools
+from collections.abc import Callable, Generator, Iterable
+from dataclasses import dataclass
+from threading import Event
+from typing import Any, Generic, TypeVar
+
+from action_msgs.srv import CancelGoal
+from rclpy import Future
+from rclpy.action.client import ClientGoalHandle
+
+from node_helpers import futures
+from node_helpers.futures import ExceptionCollector
+from node_helpers.robust_rpc import RobustActionClient
+
+FeedbackCallback = Callable[[Any], bool]
+GOAL = TypeVar("GOAL")
+T = TypeVar("T")
+
+
+@dataclass
+class ActionElement(Generic[GOAL]):
+    client: RobustActionClient
+
+    goal: GOAL
+
+    feedback_callback: FeedbackCallback | None = None
+    """A function that will be hooked up to the feedback callback. The input argument is
+    the Feedback message object and the return value should be True if the action is
+    ready to trigger the next ActionGroup in the ActionSequence.
+ """ + + +class NoRunningActionsError(Exception): + """Raised when an operation is to cancel or wait for results, but there are no + running actions""" + + +class AlreadyRunningActionsError(Exception): + """Raised when trying to execute on a group that's already running actions""" + + +def _normalize_iterable(*elements: Iterable[T] | T) -> list[T]: + """Normalize an iterable of ActionElements into a list""" + return list( + itertools.chain.from_iterable( + list(e) if isinstance(e, Iterable) else [e] for e in elements + ) + ) + + +class ActionGroup: + """A helper utility for running multiple actions simultaneously, and waiting for + completion (or a specified partial completion). + """ + + @dataclass + class _RunningAction: + result_future: Future + + goal_handle: ClientGoalHandle + """The goal_handle for the ongoing action""" + + next_action_trigger: Event + """Set either when the action is finished, or when the feedback_callback has + returned True. + """ + + def __init__( + self, *elements: ActionElement[GOAL] | Iterable[ActionElement[GOAL]] + ) -> None: + """Creates the action group and begins the actions + :param elements: Either a list of elements, or iterables of elements. + For example, ActionGroup(element_a, element_b) is acceptable, as is + ActionGroup(ActionElement(..) for client, goal in clients_and_goals) + :raises ValueError: If no items were passed in to elements. + """ + if len(elements) == 0: + # This might be reconsidered, if there's a use case for not raising an error + raise ValueError("There must be at least one element!") + + # Unpack elements or iterables of elements into a single flat list + self._action_elements: list[ActionElement[GOAL]] = _normalize_iterable( + *elements + ) + self._running_actions: list[ActionGroup._RunningAction] = [] + + # APIs that somewhat mirror Action Client + def send_goals_async(self) -> None: + """Asynchronously begin running all actions""" + if len(self._running_actions): + raise AlreadyRunningActionsError("This function can only be called once!") + + # Start all the actions async and hook up feedback callbacks + handle_futures: list[Future] = [] + next_action_trigger_events: list[Event] = [] + for element in self._action_elements: + on_feedback: Callable[[Any], None] | None + next_action_trigger = Event() + if element.feedback_callback is not None: + + def on_feedback( + feedback: Any, + callback: FeedbackCallback = element.feedback_callback, # type: ignore + trigger: Event = next_action_trigger, + ) -> None: + if callback(feedback.feedback): + trigger.set() + + else: + on_feedback = None + + handle_futures.append( + element.client.send_goal_async( + goal=element.goal, feedback_callback=on_feedback + ) + ) + next_action_trigger_events.append(next_action_trigger) + + # Sync the futures such that now we know all actions were accepted + handles = futures.wait_for_futures(handle_futures, ClientGoalHandle) + + # Hook up each action so when they finish the feedback_trigger is also called + for handle, trigger in zip(handles, next_action_trigger_events, strict=False): + result_future: Future = handle.get_result_async() + result_future.add_done_callback(lambda f, trigger=trigger: trigger.set()) + action = self._RunningAction( + next_action_trigger=trigger, + result_future=result_future, + goal_handle=handle, + ) + self._running_actions.append(action) + + def send_goals(self) -> list[Any]: + """Synchronously run all actions""" + self.send_goals_async() + return self.wait_for_results() + + def cancel_goals(self) -> None: + """Synchronously 
+        """Synchronously cancels all (if any) running actions, then blocks until done"""
+        if not len(self._running_actions):
+            return
+
+        self.cancel_goals_async()
+        self.wait_for_results()
+
+    def cancel_goals_async(self) -> None:
+        """Asynchronously cancels all (if any) running actions, without waiting for completion"""
+        if not len(self._running_actions):
+            return
+
+        # Asynchronously trigger cancellation on all running actions
+        cancellation_acceptance_futures = [
+            handler.goal_handle.cancel_goal_async() for handler in self._running_actions
+        ]
+
+        # Wait until all cancellation requests have been (presumably) accepted
+        # TODO: If there's ever a need for it, it may be worth checking that the actions
+        #   that were cancelled actually accepted the cancellation response.
+        futures.wait_for_futures(
+            cancellation_acceptance_futures, type_=CancelGoal.Response, timeout=30
+        )
+
+    # Additional APIs useful for groups, that don't mirror ActionClient APIs
+    def yield_for_results(
+        self, yield_interval: float = 0.1
+    ) -> Generator[None, None, list[Any]]:
+        """This is a helpful method that is equivalent to wait_for_results except it
+        yields periodically.
+
+        Usage example in an ActionWorker.run loop:
+        >>> group.send_goals_async()
+        >>> # The following line allows cancellation while waiting for actions to finish
+        >>> action_results = yield from group.yield_for_results()
+
+        :param yield_interval: How often to yield, while waiting for results
+        :yields: None, until all futures are complete
+        :return: The results of all the actions in the action group
+        :raises NoRunningActionsError: If no actions are running
+        """
+
+        if len(self._running_actions) == 0:
+            raise NoRunningActionsError(
+                "You must start the actions before waiting on results!"
+            )
+
+        results = yield from futures.yield_for_futures(
+            [a.result_future for a in self._running_actions], object, yield_interval
+        )
+
+        self._running_actions = []
+        return [r.result for r in results]  # type: ignore
+
+    def wait_for_results(self) -> list[Any]:
+        """Wait until all actions have returned a result"""
+
+        if len(self._running_actions) == 0:
+            raise NoRunningActionsError(
+                "You must start the actions before waiting on results!"
+            )
+
+        results = futures.wait_for_futures(
+            [r.result_future for r in self._running_actions], object
+        )
+        self._running_actions = []
+
+        # Return each result's `.result` field, instead of its {Action}_GetResult_Response
+        return [r.result for r in results]  # type: ignore
+
+    def wait_for_feedback_triggers(self) -> None:
+        """Wait until all actions' 'feedback_trigger' events are set"""
+        if len(self._running_actions) == 0:
+            raise NoRunningActionsError(
+                "You must start the actions before waiting on triggers!"
+            )
+
+        for running_action in self._running_actions:
+            running_action.next_action_trigger.wait()
+
+        # Raise exception if any of the actions have already finished and failed
+        self._check_for_exceptions()
+
+    def yield_for_feedback_triggers(self) -> Generator[None, None, None]:
+        """Yield until all actions' 'feedback_trigger' events are set"""
+        if len(self._running_actions) == 0:
+            raise NoRunningActionsError(
+                "You must start the actions before waiting on triggers!"
+            )
+
+        while not all(
+            running_action.next_action_trigger.is_set()
+            for running_action in self._running_actions
+        ):
+            yield
+
+        # Raise exception if any of the actions have already finished and failed
+        self._check_for_exceptions()
+
+    def _check_for_exceptions(self) -> None:
+        """Check for exceptions in the running actions"""
+        collector = ExceptionCollector()
+        for running_action in self._running_actions:
+            if running_action.result_future.done():
+                with collector:
+                    running_action.result_future.result()
+
+        if collector.exceptions:
+            self.wait_for_results()
+            collector.maybe_raise()
+
+
+class ActionSequence:
+    """Execute a list of ActionGroups.
+
+    Usage Example:
+    >>> sequence = ActionSequence(
+    >>>     ActionGroup(
+    >>>         ActionElement(client, goal, feedback_callback=ready_to_continue),
+    >>>         ActionElement(client, goal, feedback_callback=ready_to_continue),
+    >>>     ),
+    >>>     ActionGroup(
+    >>>         ActionElement(client, goal),
+    >>>     )
+    >>> )
+
+    In the above code, the first action group will run in parallel until all the
+    'ready_to_continue' functions set their respective feedback_trigger events, at
+    which point the next ActionGroup will begin to run in parallel.
+
+    If no 'ready_to_continue' function is passed, then the ActionGroups will be run
+    one after the other, synchronously (waiting for results from all child actions
+    within the group before continuing on to the next group).
+    """
+
+    def __init__(self, *action_groups: ActionGroup | Iterable[ActionGroup]) -> None:
+        self._action_groups = _normalize_iterable(*action_groups)
+
+    def execute(self) -> list[list[Any]]:
+        """Execute all action groups sequentially"""
+        for action_group in self._action_groups:
+            action_group.send_goals_async()
+            try:
+                action_group.wait_for_feedback_triggers()
+            except Exception:
+                action_group.wait_for_results()
+                raise
+
+        return self.wait_for_results()
+
+    def yield_for_execution(self) -> Generator[None, None, list[list[Any]]]:
+        """Execute all action groups sequentially, yielding periodically"""
+        try:
+            for action_group in self._action_groups:
+                action_group.send_goals_async()
+                yield from action_group.yield_for_feedback_triggers()
+        except Exception:
+            # Block until all actions are finished
+            action_group.wait_for_results()
+            raise
+
+        # Wait for all actions to finish executing, raising potential exceptions
+        results = []
+        for action_group in self._action_groups:
+            result = yield from action_group.yield_for_results()
+            results.append(result)
+
+        return results
+
+    def cancel_async(self) -> None:
+        """Cancel all running actions"""
+        for action_group in self._action_groups:
+            action_group.cancel_goals_async()
+
+    def cancel(self) -> None:
+        """This finishes any ongoing cancellations, and waits for results"""
+        # Begin cancelling async
+        for action_group in self._action_groups:
+            action_group.cancel_goals_async()
+
+        # Block until all actions are cancelled
+        for action_group in self._action_groups:
+            action_group.cancel_goals()
+
+    def wait_for_results(self) -> list[list[Any]]:
+        """Wait until all actions have returned a result or finished cancelling"""
+        return [
+            action_group.wait_for_results() for action_group in self._action_groups
+        ]
+
+
+class ParallelActionSequences:
+    """Useful when you need to run parallel tracks of ActionSequences.
+
+    For example, let's say you need to run the sequence:
+
+    1. grab nails
+    2. go to home position
+    3. retract
+
+    But you need to do it for 2 different robots, in parallel. You don't want to
You don't want to + wait for the first robot to finish grabbing nails before the second robot can + start grabbing nails. This is where ParallelActionSequences comes in. + """ + + def __init__( + self, *action_sequences: ActionSequence | Iterable[ActionSequence] + ) -> None: + self._action_sequences = _normalize_iterable(*action_sequences) + + def yield_for_execution(self) -> Generator[None, None, list[list[list[Any]]]]: + """Execute all action sequences in parallel, yielding + + :yield: None, until all action sequences are finished + :return: A list of where + - Each element in the list is the result of an ActionSequence + - Each element in the inner list is the results of actions of an ActionGroup + - Each element in the inner-inner list is a ROS action Result message + :raises Exception: If any of the actions raise exceptions, the first will be + re-raised here. + """ + + # Get generators for each action sequence + sequence_generators = [ + seq.yield_for_execution() for seq in self._action_sequences + ] + + # Keep track of whether each generator is done + done = [False] * len(sequence_generators) + results = [None] * len(sequence_generators) + collector = ExceptionCollector() + while not all(done): + for seq_idx, gen in enumerate(sequence_generators): + if not done[seq_idx]: + with collector: + try: + next(gen) + except StopIteration as e: + results[seq_idx] = e.value + done[seq_idx] = True + except Exception: + done[seq_idx] = True + raise + + # Yield inbetween checks of the underlying generators + yield + + collector.maybe_raise() + + return results # type: ignore + + def cancel_async(self) -> None: + """Cancel all running actions""" + for action_sequence in self._action_sequences: + action_sequence.cancel_async() + + def cancel(self) -> None: + """This finishes any ongoing cancellations, and waits for results""" + self.cancel_async() + + for action_sequence in self._action_sequences: + action_sequence.cancel() diff --git a/pkgs/node_helpers/node_helpers/actions/context_manager.py b/pkgs/node_helpers/node_helpers/actions/context_manager.py new file mode 100644 index 0000000..9ee7630 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/actions/context_manager.py @@ -0,0 +1,143 @@ +from collections.abc import Callable +from contextlib import AbstractContextManager +from threading import Event +from types import TracebackType +from typing import Any, cast + +from rclpy import Future +from rclpy.action.client import ClientGoalHandle + +from node_helpers.futures import wait_for_future +from node_helpers.robust_rpc import RobustActionClient + + +class ActionContextManager: + """A generic wrapper that turns specifically designed, long-running actions into + context managers. + + This is intended for use with actions that start some operation, produce a feedback + message indicating that the operation has started, and then run indefinitely until + they are cancelled. + + By default, ActionContextManager synchronously waits for feedback on enter and waits + for cancellation on exit. + + >>> with ActionContextManager(client, goal) as mgr: + >>> # Server provided feedback by this point + >>> # Do some work + + In async mode, the context manager does not wait for feedback when entered and + asynchronously cancels the goal when exited. At any time while the context manager + is active, the user can choose to wait for feedback to be received. After the + context manager has exited, the user can wait for cancellation to finish. 
+ + >>> with ActionContextManager(client, goal, async_=True) as mgr: + >>> # Do some work + >>> mgr.wait_for_feedback() + >>> # Do some more work + >>> + >>> mgr.wait_for_cancellation() + """ + + def __init__( + self, + client: RobustActionClient, + goal: Any, + timeout: float | None = None, + async_: bool = False, + on_feedback: Callable[[Any], None] | None = None, + ): + """ + :param client: A client for the action to call + Make sure your client has feedback_sub_qos_profile set to + qos_profile_services_default, or else it might not robustly receive feedback + :param goal: The goal to send through the client + :param async_: Controls asynchronous mode + :param timeout: A timeout to use when calling and cancelling the action. If + None, no timeout is used. + :param on_feedback: An optional callback to call when feedback is received. + Streams the {MESSAGE}_Feedback object as they come in. + """ + self._client = client + self._goal = goal + self._goal_context: AbstractContextManager[ClientGoalHandle] | None = None + self._timeout = timeout + self._got_feedback = Event() + self._result_future: Future | None = None + self._async = async_ + self._on_feedback = on_feedback + + def __enter__(self) -> "ActionContextManager": + # Track when the first feedback is received by triggering 'got_feedback' + def _feedback_wrapper(feedback: Any) -> None: + self._got_feedback.set() + if self._on_feedback is not None: + # Pass on the feedback to the user's callback, if one was provided + self._on_feedback(feedback.feedback) + + self._goal_context = self._client.send_goal_as_context( + self._goal, + feedback_callback=_feedback_wrapper, + timeout=self._timeout, + block_for_result=False, + ) + client_handle = self._goal_context.__enter__() + self._result_future = client_handle.get_result_async() + + if not self._async: + self.wait_for_feedback() + + return self + + def activate(self) -> "ActionContextManager": + return self.__enter__() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self._goal_context is not None: + self._goal_context.__exit__(exc_type, exc_val, exc_tb) + self._goal_context = None + if not self._async: + self.wait_for_cancellation() + + self.check_for_exceptions() + + def deactivate(self) -> None: + if self.activated: + return self.__exit__(None, None, None) + + def check_for_exceptions(self) -> None: + """Check for and raise any remote exceptions""" + if self._result_future is None: + raise RuntimeError("check_for_exceptions called before entering context") + + if self._result_future.done(): + self._result_future.result() + + def wait_for_feedback(self) -> None: + if self._result_future is None: + raise RuntimeError("wait_for_feedback called before entering context") + + while not self._got_feedback.wait(0.01): + # Check if the action finished on its own. This is usually because of an + # exception, but could also happen due to a bug on the server side. 
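+            # (i.e. the server returned a result before its first feedback message,
+            # which a long-running, context-style action should never do)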
+ if self._result_future.done(): + self._result_future.result() # Potentially raise an exception + raise RuntimeError("Action finished on its own without being cancelled") + + def wait_for_cancellation(self) -> None: + wait_for_future(self._result_future, object, timeout=self._timeout) + + @property + def done(self) -> bool: + if self._result_future is None: + raise RuntimeError("Done called before entering context") + return cast(bool, self._result_future.done()) + + @property + def activated(self) -> bool: + return self._goal_context is not None diff --git a/pkgs/node_helpers/node_helpers/actions/generators.py b/pkgs/node_helpers/node_helpers/actions/generators.py new file mode 100644 index 0000000..f3224b8 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/actions/generators.py @@ -0,0 +1,35 @@ +from collections.abc import Generator +from time import sleep, time + + +def generator_sleep(seconds: float) -> Generator[None, None, None]: + """Like time.sleep, but routinely yields to allow for preemption""" + + start = time() + elapsed = 0.0 + + while elapsed < seconds: + elapsed = time() - start + + if 0 <= seconds - elapsed < _SLEEP_INTERVAL: + # The sleep is too short to bother with preemption + sleep(seconds - elapsed) + + # Yield at least once, in case the sleep was too short and the function + # hasn't yielded yet. + yield None + return + + sleep(_SLEEP_INTERVAL) + yield None + + +def yield_until(delay: float) -> Generator[None, None, None]: + """Yields until the provided delay has passed without sleeping""" + + start = time() + while time() - delay < start: + yield None + + +_SLEEP_INTERVAL = 0.005 diff --git a/pkgs/node_helpers/node_helpers/actions/server/__init__.py b/pkgs/node_helpers/node_helpers/actions/server/__init__.py new file mode 100644 index 0000000..962f8f0 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/actions/server/__init__.py @@ -0,0 +1,5 @@ +from .base_handler import ActionCallMetric, ActionHandler +from .context_action_handler import ContextActionHandler +from .fail_fast_handler import FailFastActionHandler, SynchronousActionCalledInParallel +from .queued_handler import QueuedActionHandler +from .worker import ActionTimeoutError, ActionWorker, NoResultSetError diff --git a/pkgs/node_helpers/node_helpers/actions/server/_typing.py b/pkgs/node_helpers/node_helpers/actions/server/_typing.py new file mode 100644 index 0000000..81eee6f --- /dev/null +++ b/pkgs/node_helpers/node_helpers/actions/server/_typing.py @@ -0,0 +1,7 @@ +from typing import TypeVar + +from node_helpers.robust_rpc.typing import ResponseType + +GOAL_TYPE = TypeVar("GOAL_TYPE") +FEEDBACK_TYPE = TypeVar("FEEDBACK_TYPE") +RESULT_TYPE = TypeVar("RESULT_TYPE", bound=ResponseType) diff --git a/pkgs/node_helpers/node_helpers/actions/server/base_handler.py b/pkgs/node_helpers/node_helpers/actions/server/base_handler.py new file mode 100644 index 0000000..daca647 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/actions/server/base_handler.py @@ -0,0 +1,141 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Generic, Literal +from uuid import UUID + +from action_msgs.msg import GoalStatus +from action_msgs.srv import CancelGoal +from rclpy.action.server import CancelResponse, ServerGoalHandle +from rclpy.callback_groups import CallbackGroup + +from node_helpers.nodes import HelpfulNode +from node_helpers.robust_rpc import RobustRPCException +from node_helpers.robust_rpc.typing import RobustActionMsg +from 
node_helpers.timing import Timer
+
+from ._typing import FEEDBACK_TYPE, GOAL_TYPE, RESULT_TYPE
+from .worker import ActionWorker
+
+
+@dataclass
+class ActionCallMetric:
+    """A metric to be reported on an action, and passed to the metrics_callback"""
+
+    action_name: str
+    node_namespace: str
+    goal_id: UUID
+
+    elapsed: float | None = None
+    error_name: str | None = None
+    error_description: str | None = None
+    result: Literal["success", "error", "canceled", "in_progress"] | None = None
+
+
+class ActionHandler(ABC, Generic[GOAL_TYPE, FEEDBACK_TYPE, RESULT_TYPE]):
+    """Creates and runs an ActionWorker instance when an action request is received"""
+
+    def __init__(
+        self,
+        node: HelpfulNode,
+        action_name: str,
+        action_type: type[RobustActionMsg],
+        callback_group: CallbackGroup,
+        cancellable: bool = True,
+        metrics_callback: Callable[[ActionCallMetric], None] | None = None,
+        **action_server_kwargs: Any,
+    ):
+        """
+        :param node: The node to create the action server on
+        :param action_name: The name of the action
+        :param action_type: The ROS message type of the action
+        :param callback_group: The callback group to use for the action server
+        :param cancellable: Whether the action can be canceled
+        :param metrics_callback: A callback to call when metrics are reported. This is
+            useful for testing or for hooking up to a metrics system that records all
+            action calls to a database.
+
+            This will be called twice for each action call:
+            - When the action is first called, with a result of "in_progress"
+            - When the action is completed with "success", "error", or "canceled"
+
+        :param action_server_kwargs: Additional keyword arguments to pass to the action
+            server
+        """
+
+        self._request_timeout: float | None = node.declare_and_get_parameter(
+            f"{action_name}_action_timeout",
+            type_=float,
+            description=f"The timeout, in seconds, to use for the {action_name} action",
+            default_value=0.0,
+        )
+
+        # A timeout of 0.0 (the default) means the action should never time out
+        self._request_timeout = self._request_timeout or None
+
+        self.node = node
+        self.action_name = action_name
+        self._callback_group = callback_group
+
+        self._metrics_callback = metrics_callback or (lambda _: None)
+
+        def cancel_callback(_request: CancelGoal.Request) -> CancelResponse:
+            return CancelResponse.ACCEPT if cancellable else CancelResponse.REJECT
+
+        self._server = node.create_robust_action_server(
+            action_type=action_type,
+            action_name=action_name,
+            execute_callback=self._on_goal_with_report,
+            callback_group=self._callback_group,
+            cancel_callback=cancel_callback,
+            **action_server_kwargs,
+        )
+
+    @abstractmethod
+    def create_worker(
+        self, goal_handle: ServerGoalHandle
+    ) -> ActionWorker[GOAL_TYPE, FEEDBACK_TYPE, RESULT_TYPE]:
+        """Called when a worker needs to be made for a new goal"""
+
+    def on_goal(self, goal_handle: ServerGoalHandle) -> RESULT_TYPE:
+        """Called when a goal request is received.
May be overridden by subclasses.""" + worker = self.create_worker(goal_handle) + return worker.execute_callback(self._request_timeout) + + def _on_goal_with_report(self, goal_handle: ServerGoalHandle) -> RESULT_TYPE: + """Wraps _on_goal and publishes a metric reporting information on the action""" + metric = ActionCallMetric( + action_name=self.action_name, + node_namespace=self.node.get_namespace(), + goal_id=UUID(bytes=bytes(goal_handle.goal_id.uuid)), + result="in_progress", + ) + + # Always publish an 'in_progress' report immediately + self._metrics_callback(metric) + + timer = Timer() + try: + with timer: + result = self.on_goal(goal_handle) + + # If there were no errors, that's considered a success! + metric.result = "success" + return result + except Exception as ex: + # Record information on the error, including a description from the error if + # it's available + if isinstance(ex, RobustRPCException): + metric.error_name = ex.error_name + metric.error_description = ex.error_description + else: + metric.error_name = type(ex).__name__ + metric.error_description = str(ex) + metric.result = "error" + raise + finally: + if goal_handle.status == GoalStatus.STATUS_CANCELED: + metric.result = "canceled" + + # Regardless of the result status, always record the final time + metric.elapsed = timer.elapsed + self._metrics_callback(metric) diff --git a/pkgs/node_helpers/node_helpers/actions/server/context_action_handler.py b/pkgs/node_helpers/node_helpers/actions/server/context_action_handler.py new file mode 100644 index 0000000..e4519ed --- /dev/null +++ b/pkgs/node_helpers/node_helpers/actions/server/context_action_handler.py @@ -0,0 +1,29 @@ +from rclpy.callback_groups import ReentrantCallbackGroup +from rclpy.qos import qos_profile_services_default + +from node_helpers.nodes import HelpfulNode +from node_helpers.robust_rpc.typing import RobustActionMsg + +from ._typing import FEEDBACK_TYPE, GOAL_TYPE, RESULT_TYPE +from .base_handler import ActionHandler + + +class ContextActionHandler(ActionHandler[GOAL_TYPE, FEEDBACK_TYPE, RESULT_TYPE]): + """This is a preconfigured action handler for making context handler actions. 
+ Specifically, it ensures that the feedback_pub_qos_profile is set to + services_default, and that the action is cancellable.""" + + def __init__( + self, + node: HelpfulNode, + action_name: str, + action_type: type[RobustActionMsg], + ): + super().__init__( + node=node, + action_name=action_name, + action_type=action_type, + callback_group=ReentrantCallbackGroup(), + cancellable=True, + feedback_pub_qos_profile=qos_profile_services_default, + ) diff --git a/pkgs/node_helpers/node_helpers/actions/server/fail_fast_handler.py b/pkgs/node_helpers/node_helpers/actions/server/fail_fast_handler.py new file mode 100644 index 0000000..6c314e2 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/actions/server/fail_fast_handler.py @@ -0,0 +1,57 @@ +from threading import RLock +from typing import Any + +from rclpy.action.server import ServerGoalHandle +from rclpy.callback_groups import ReentrantCallbackGroup + +from node_helpers.nodes import HelpfulNode +from node_helpers.robust_rpc.typing import RobustActionMsg + +from ._typing import FEEDBACK_TYPE, GOAL_TYPE, RESULT_TYPE +from .base_handler import ActionHandler + + +class SynchronousActionCalledInParallel(Exception): + """This exception is raised when an exception that should only ever be called + synchronously is called in parallel, meaning there's likely a bug somewhere.""" + + +class FailFastActionHandler(ActionHandler[GOAL_TYPE, FEEDBACK_TYPE, RESULT_TYPE]): + """This action handler will immediately raise an exception if the action is called + in parallel. This is useful for defining an action that should only ever be called + synchronously, and you want to prevent situations where the code runs in parallel. + """ + + def __init__( + self, + node: HelpfulNode, + action_name: str, + action_type: type[RobustActionMsg], + cancellable: bool = True, + **action_server_kwargs: dict[str, Any], + ): + super().__init__( + node=node, + action_name=action_name, + action_type=action_type, + callback_group=ReentrantCallbackGroup(), + cancellable=cancellable, + metrics_callback=None, + **action_server_kwargs, + ) + + self.action_lock = RLock() + """This lock is held while an action is being run""" + + def on_goal(self, goal_handle: ServerGoalHandle) -> RESULT_TYPE: + acquired = self.action_lock.acquire(blocking=False) + if not acquired: + raise SynchronousActionCalledInParallel( + f"The action '{self.action_name}' was called while another call was in" + f" progress!" 
+ ) + + try: + return super().on_goal(goal_handle) + finally: + self.action_lock.release() diff --git a/pkgs/node_helpers/node_helpers/actions/server/queued_handler.py b/pkgs/node_helpers/node_helpers/actions/server/queued_handler.py new file mode 100644 index 0000000..55d18b0 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/actions/server/queued_handler.py @@ -0,0 +1,88 @@ +import queue +from threading import Event + +from rclpy.action.server import ServerGoalHandle +from rclpy.callback_groups import ReentrantCallbackGroup + +from node_helpers.nodes import HelpfulNode +from node_helpers.robust_rpc.typing import RobustActionMsg + +from ._typing import FEEDBACK_TYPE, GOAL_TYPE, RESULT_TYPE +from .base_handler import ActionHandler +from .worker import ActionWorker + + +class QueuedActionHandler(ActionHandler[GOAL_TYPE, FEEDBACK_TYPE, RESULT_TYPE]): + """Gives requests to ActionWorkers one-by-one, ensuring the action is not run more + than once at a time + """ + + def __init__( + self, + node: HelpfulNode, + action_name: str, + action_type: type[RobustActionMsg], + cancellable: bool = True, + ): + super().__init__( + node=node, + action_name=action_name, + action_type=action_type, + callback_group=ReentrantCallbackGroup(), + cancellable=cancellable, + ) + + self._current_worker: ( + None | (ActionWorker[GOAL_TYPE, FEEDBACK_TYPE, RESULT_TYPE]) + ) = None + + # A queue of workers waiting to run, and an event used to signal when it is + # time for that worker to run + self._worker_queue: "queue.Queue[tuple[ActionWorker[GOAL_TYPE, FEEDBACK_TYPE, RESULT_TYPE], Event]]" = ( # noqa: E501 + queue.Queue() + ) + + self._find_work_timer = node.create_timer( + 0.05, self._find_work, callback_group=self._callback_group + ) + + def is_working(self) -> bool: + """ + :return: If the handler is currently processing a request + """ + return self._current_worker is not None + + def on_goal(self, goal_handle: ServerGoalHandle) -> RESULT_TYPE: + worker = self.create_worker(goal_handle) + + queued_actions = self._worker_queue.qsize() + if queued_actions > 0: + self.node.get_logger().warning( + f"{queued_actions} actions waiting on the job queue for " + f"{self.action_name}. 
Requests may be delayed" + ) + + start_working = Event() + + self._worker_queue.put((worker, start_working)) + + # Wait until it's this worker's turn to run + start_working.wait() + + return worker.execute_callback(self._request_timeout) + + def _find_work(self) -> None: + """Looks for new requests in the action queue, if no action is currently being + run + """ + + if self._current_worker is not None and self._current_worker.done: + self._current_worker = None + + if self._current_worker is None: + try: + worker, start_executing = self._worker_queue.get_nowait() + self._current_worker = worker + start_executing.set() + except queue.Empty: + pass diff --git a/pkgs/node_helpers/node_helpers/actions/server/worker.py b/pkgs/node_helpers/node_helpers/actions/server/worker.py new file mode 100644 index 0000000..005e559 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/actions/server/worker.py @@ -0,0 +1,116 @@ +import logging +from abc import ABC, abstractmethod +from collections.abc import Generator +from typing import Generic, cast + +from rclpy.action.server import ServerGoalHandle + +from node_helpers.timing import Timeout + +from ._typing import FEEDBACK_TYPE, GOAL_TYPE, RESULT_TYPE + + +class ActionTimeoutError(Exception): + """The action took longer than the timeout to run""" + + +class NoResultSetError(Exception): + """The action implementation completed successfully, but no result value was set""" + + +class ActionWorker(ABC, Generic[GOAL_TYPE, FEEDBACK_TYPE, RESULT_TYPE]): + def __init__(self, goal_handle: ServerGoalHandle): + self.goal_handle = goal_handle + """A handle used to communicate with the caller. Implementations usually do not + need to use this object directly. See ``run`` for details. + """ + + self.goal = cast(GOAL_TYPE, goal_handle.request) + """The goal data sent by the caller, containing information on the request""" + + self.done = False + """True if the worker is no longer running, including as a result of an error + or cancellation + """ + + self.result: RESULT_TYPE | None = None + """The result value to provide to the caller. This must be set by ``run`` if + the work completed successfully. + """ + + @abstractmethod + def run(self) -> Generator[FEEDBACK_TYPE | None, None, None]: + """Triggered when an action begins execution. This function is expected to + regularly yield (feedback objects or None) and set the ``result`` attribute + before returning. + + This method will be automatically interrupted after a yield if a timeout is + reached or if the action is canceled. This means that, for best operation, + implementations of this method should yield regularly. 
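+
+        As an illustrative sketch of a run implementation (the action type
+        ``MyAction`` here is hypothetical, not part of this module):
+
+        >>> def run(self):
+        >>>     for step in range(10):
+        >>>         do_one_step()  # hypothetical unit of work
+        >>>         yield MyAction.Feedback(progress=step / 10)
+        >>>     self.result = MyAction.Result(success=True)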
+ """ + + @abstractmethod + def on_cancel(self) -> RESULT_TYPE: + """Triggered when the client cancels the request + + :return: The result to provide to the client + """ + + def on_exception(self, ex: Exception) -> None: + """Triggered when an exception happens in the ``run`` method""" + + def execute_callback(self, timeout: float | None) -> RESULT_TYPE: + """Runs the action, blocking until it's finished + + :param timeout: The maximum allowed time for the action to take, or None for + no limit + :return: The result to provide to the client + :raises ActionTimeoutError: If the action is taking longer than the provided + timeout + :raises NoResultSetError: If the ``run`` method succeeds, but does not set the + result attribute + :raises RuntimeError: If something strange happens + """ + + try: + run_generator = self.run() + + timeout_obj: bool | Timeout = ( + Timeout(timeout) if timeout is not None else True + ) + + while timeout_obj and not self.goal_handle.is_cancel_requested: + try: + feedback = next(run_generator) + if feedback is not None: + self.goal_handle.publish_feedback(feedback) + except StopIteration as e: + if self.result is None: + raise NoResultSetError( + "Action failed to produce a result" + ) from e + self.goal_handle.succeed() + return self.result + except Exception as ex: + self.on_exception(ex) + raise + + if self.goal_handle.is_cancel_requested: + logging.debug(f"Canceling action. Goal={self.goal}") + + # Forces immediate garbage collection on the generator, ensuring that + # context managers and `finally` blocks are run before cancellation. + del run_generator + + result = self.on_cancel() + self.goal_handle.canceled() + return result + elif not timeout_obj: + raise ActionTimeoutError("Timeout while running the action") + + raise RuntimeError( + "Action stopped running before a cancel request or timeout! This " + "should never happen!" + ) + finally: + self.done = True diff --git a/pkgs/node_helpers/node_helpers/async_tools/README.md b/pkgs/node_helpers/node_helpers/async_tools/README.md new file mode 100644 index 0000000..45951b8 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/async_tools/README.md @@ -0,0 +1,8 @@ +# node_helpers.async_tools + +The node_helpers.async_tools module provides a utility class, AsyncAdapter, to bridge the gap between traditional ROS 2 nodes and asyncio-based asynchronous programming. This is particularly useful for integrating modern Python asynchronous networking and workflows into ROS 2 systems, which predominantly rely on callback-driven and multithreaded architectures. + +## Key Features +**Background Event Loop**: Creates and manages a dedicated asyncio event loop that runs independently of the ROS event loop. +**Async Callback Adaptation**: Converts asyncio coroutines into functions compatible with ROS callback mechanisms, enabling seamless integration of async workflows. +**Clean Shutdown**: Ensures proper cleanup of the asyncio event loop during node destruction. 
\ No newline at end of file diff --git a/pkgs/node_helpers/node_helpers/async_tools/__init__.py b/pkgs/node_helpers/node_helpers/async_tools/__init__.py new file mode 100644 index 0000000..5e85824 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/async_tools/__init__.py @@ -0,0 +1 @@ +from .async_adapter import AsyncAdapter diff --git a/pkgs/node_helpers/node_helpers/async_tools/async_adapter.py b/pkgs/node_helpers/node_helpers/async_tools/async_adapter.py new file mode 100644 index 0000000..ad12742 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/async_tools/async_adapter.py @@ -0,0 +1,85 @@ +import asyncio +import functools +from collections.abc import Callable, Coroutine +from concurrent.futures import Future as ThreadingFuture +from threading import Event as ThreadingEvent +from typing import Any, TypeVar + +from rclpy.callback_groups import MutuallyExclusiveCallbackGroup + +from node_helpers.nodes import HelpfulNode + +_CORO_RETURN = TypeVar("_CORO_RETURN") + + +class AsyncAdapter: + """Provides a bridge between asyncio code and a traditional ROS node with a + multithreaded executor. This takes care of creating an event loop that runs in the + background, separate from ROS's event loop. + + Rclpy actually does have some native support for using coroutines as callbacks, but + that support is not well-documented and has some limitations. Namely, the coroutines + are run directly as generators instead of through an asyncio event loop. This + (I think) means that asyncio networking won't work, which we need. + + In the future, we could consider running the Rclpy event loop inside an asyncio + event loop. + """ + + def __init__(self, node: HelpfulNode): + self.event_loop = asyncio.get_event_loop() + + node.create_single_shot_timer( + 0.0, + self._run_event_loop, + callback_group=MutuallyExclusiveCallbackGroup(), + ) + node.on_destroy(self._stop_event_loop) + + def adapt( + self, callback: Callable[..., Coroutine[Any, Any, _CORO_RETURN]] + ) -> Callable[..., _CORO_RETURN]: + """Adapts an async function to be run in the event loop of the node when called. + The resulting function will block until the coroutine has completed. This is + intended for use with topic and service subscriber callbacks. 
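+
+        For example, a rough sketch of subscribing with an async callback (the
+        ``node``, topic, and ``String`` message type here are illustrative):
+
+        >>> adapter = AsyncAdapter(node)
+        >>> async def on_message(msg: String) -> None:
+        >>>     await asyncio.sleep(0.1)  # any asyncio-based work
+        >>> node.create_subscription(String, "chatter", adapter.adapt(on_message), 10)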
+ + :param callback: The async function to wrap + :return: The return value of the coroutine once it's completed + """ + + @functools.wraps(callback) + def wrapper(*args: Any, **kwargs: Any) -> _CORO_RETURN: + future: ThreadingFuture[_CORO_RETURN] = ThreadingFuture() + + def run_task() -> None: + new_task = self.event_loop.create_task(callback(*args, **kwargs)) + + def on_done(task: asyncio.Task[_CORO_RETURN]) -> None: + if task.exception(): + future.set_exception(task.exception()) + else: + future.set_result(task.result()) + + new_task.add_done_callback(on_done) + + self.event_loop.call_soon_threadsafe(run_task) + return future.result() + + return wrapper + + def _run_event_loop(self) -> None: + try: + self.event_loop.run_forever() + finally: + self.event_loop.close() + + def _stop_event_loop(self) -> None: + stopped = ThreadingEvent() + + def stop() -> None: + self.event_loop.stop() + stopped.set() + + self.event_loop.call_soon_threadsafe(stop) + if not stopped.wait(5.0): + raise RuntimeError("Timeout while stopping asyncio event loop") diff --git a/pkgs/node_helpers/node_helpers/destruction/README.md b/pkgs/node_helpers/node_helpers/destruction/README.md new file mode 100644 index 0000000..32b9939 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/destruction/README.md @@ -0,0 +1,6 @@ +# node_helpers.destruction + +This module provides a utility mixin, DestroyCallbacksMixin, to enhance ROS 2 nodes with a flexible destruction lifecycle. It enables developers to register custom callbacks that are invoked when the node is destroyed, ensuring proper resource cleanup and lifecycle management. + +This mixin is particularly useful for managing resources like threads, file handles, or other external dependencies that need explicit teardown when a node is shut down. + diff --git a/pkgs/node_helpers/node_helpers/destruction/__init__.py b/pkgs/node_helpers/node_helpers/destruction/__init__.py new file mode 100644 index 0000000..ffc6075 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/destruction/__init__.py @@ -0,0 +1 @@ +from .mixin import DestroyCallbacksMixin diff --git a/pkgs/node_helpers/node_helpers/destruction/mixin.py b/pkgs/node_helpers/node_helpers/destruction/mixin.py new file mode 100644 index 0000000..c017c95 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/destruction/mixin.py @@ -0,0 +1,34 @@ +from collections.abc import Callable +from typing import Any, cast + +from rclpy.node import Node + + +class DestroyCallbacksMixin: + _destroy_callbacks: list[Callable[[], Any]] + + def on_destroy(self, callback: Callable[[], Any]) -> None: + """Registers a callback that the node will call when destroyed. Useful for + cleaning up resources. Multiple callbacks can be registered. + + :param callback: The function to call + """ + self._get_destroy_callbacks().append(callback) + + def destroy_node(self) -> None: + """Overrides Node.destroy_node() to call registered callbacks""" + for callback in self._get_destroy_callbacks(): + callback() + cast(Node, super()).destroy_node() + + def _get_destroy_callbacks(self) -> list[Callable[[], Any]]: + """Lazily creates the attribute that holds destroy callbacks. We do this to + avoid having to deal with __init__ methods when using multiple inheritance. 
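+
+        As a sketch of the intended pattern (``MyNode`` is illustrative), the mixin
+        composes with Node without any __init__ plumbing:
+
+        >>> class MyNode(DestroyCallbacksMixin, Node):
+        >>>     pass
+        >>> node = MyNode("my_node")
+        >>> node.on_destroy(lambda: print("cleaning up"))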
+ + :return: The list of callbacks + """ + try: + return self._destroy_callbacks + except AttributeError: + self._destroy_callbacks = [] + return self._destroy_callbacks diff --git a/pkgs/node_helpers/node_helpers/futures/README.md b/pkgs/node_helpers/node_helpers/futures/README.md new file mode 100644 index 0000000..ec19054 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/futures/README.md @@ -0,0 +1,19 @@ +# node_helpers.futures + +This module provides utility functions to simplify working with `rclpy.Future` objects in ROS 2. These tools make it easier to manage asynchronous tasks, including waiting for completion, handling timeouts, preemption, and executing actions with a clean, synchronous-like interface. + +The module is especially useful for scenarios where you need robust handling of ROS futures in a multi-threaded environment or want to integrate asynchronous workflows without adding complexity. + + +For example: +```python3 + +from node_helpers import futures + +future = node.call_some_action() +result = futures.wait_for_future(future, ExpectedType, timeout=5.0) + +# Or with multiple futures +results = wait_for_futures([future1, future2], ExpectedType, timeout=10.0) + +``` \ No newline at end of file diff --git a/pkgs/node_helpers/node_helpers/futures/__init__.py b/pkgs/node_helpers/node_helpers/futures/__init__.py new file mode 100644 index 0000000..52b0839 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/futures/__init__.py @@ -0,0 +1,200 @@ +from collections.abc import Callable, Generator +from threading import Event +from typing import Any, TypeVar, cast +from unittest.mock import Mock + +from action_msgs.msg import GoalStatus +from rclpy import Future +from rclpy.action import ActionClient +from rclpy.action.client import ClientGoalHandle + +from .exceptions import ExceptionCollector + +T = TypeVar("T") + + +def wait_for_future(future: Future, type_: type[T], timeout: float | None = None) -> T: + result_ready = Event() + future.add_done_callback(lambda f: result_ready.set()) + + future_done: bool = result_ready.wait(timeout=timeout) + if not future_done: + raise TimeoutError("Timeout while waiting for the future to finish!") + + # Exceptions may be raised from result() + result = future.result() + + if isinstance(result, Mock): + # Allow mocks without question to make testing easier + return cast(T, result) + else: + assert isinstance( + result, type_ + ), f"Expected result to be of type {type_} but instead got {type(result)}" + return result + + +def wait_for_preemptible_future( + goal_handle: ClientGoalHandle, + preempt_condition: Callable[[], bool], + yield_interval: float = 0.01, +) -> bool: + """ + Run an action asynchronously, allowing for cancellation if a condition is met. + :param goal_handle: The handle to cancel if necessary + :param preempt_condition: if true, cancel the goal + :param yield_interval: How often to yield inbetween waiting for the future. 
+    :returns: True if the action ran to completion, False if it was preempted and
+        cancelled
+    """
+    action_done = Event()
+    future = goal_handle.get_result_async()
+    future.add_done_callback(lambda f: action_done.set())
+    while not action_done.wait(yield_interval):
+        status = goal_handle.status
+
+        if (
+            preempt_condition()
+            and not future.done()
+            and status != GoalStatus.STATUS_CANCELED
+            and status != GoalStatus.STATUS_CANCELING
+        ):
+            goal_handle.cancel_goal()
+            goal_handle.get_result()
+            return False
+    return True
+
+
+def wait_for_futures(
+    futures: list[Future], type_: type[T], timeout: float | None = None
+) -> list[T]:
+    """
+    Block until all futures are complete. If an exception occurs, the first is raised
+    """
+
+    results: list[T] = []
+    collector = ExceptionCollector()
+    for future in futures:
+        with collector:
+            results.append(wait_for_future(future, type_, timeout))
+
+    collector.maybe_raise()
+
+    return results
+
+
+def wait_for_send_goal(
+    client: ActionClient,
+    goal: Any,
+    feedback_callback: Callable[[Any], None] | None = None,
+) -> tuple[Future, ClientGoalHandle]:
+    """A shorthand way to asynchronously call an action and get a get_result future"""
+    future = client.send_goal_async(goal, feedback_callback=feedback_callback)
+    goal_handle = wait_for_future(future, ClientGoalHandle)
+    return goal_handle.get_result_async(), goal_handle
+
+
+def yield_for_futures(
+    futures: list[Future], type_: type[T], yield_interval: float
+) -> Generator[None, None, list[T]]:
+    """
+    Yield until all futures are complete. If an exception occurs, the first is raised
+    """
+
+    results: list[T] = []
+    collector = ExceptionCollector()
+    for future in futures:
+        with collector:
+            result = yield from yield_for_future(
+                future, type_, yield_interval=yield_interval
+            )
+            results.append(result)
+
+    collector.maybe_raise()
+
+    return results
+
+
+def yield_for_future(
+    future: Future, type_: type[T], yield_interval: float
+) -> Generator[None, None, T]:
+    """This function will yield `None` until the future is complete, then return the
+    result.
+
+    This behavior is useful for waiting for actions to finish while allowing
+    cancellation in an ActionWorker.
+
+    Usage example, in an action worker "run" loop. In this example, cancellation will be
+    checked every 0.1 seconds, and when the future finishes its value will be stored in
+    'result' (where ResultType stands in for the expected result type):
+    >>> future = client.call_async()
+    >>> result = yield from yield_for_future(future, ResultType, yield_interval=0.1)
+
+    :param future: The future to wait for
+    :param type_: The type of the result that will be returned
+    :param yield_interval: How often to yield in between waiting for the future.
+    :yields: None, until the future is complete
+    :return: The resulting value of the future.
+    """
+
+    result_ready = Event()
+    future.add_done_callback(lambda f: result_ready.set())
+
+    while not result_ready.wait(timeout=yield_interval):
+        yield None
+
+    # Exceptions may be raised from result()
+    result = future.result()
+
+    assert isinstance(
+        result, type_
+    ), f"Expected result to be of type {type_} but instead got {type(result)}"
+    return result
+
+
+def run_action_with_timeout(
+    client: ActionClient, goal: Any, response_type: type[T], timeout: float = 60.0
+) -> T:
+    """Run an action and block for the response, raising a TimeoutError if it takes too
+    long. This function is intended for test usage, since in production it's typically a
+    bad idea to use timeouts for actions that might still be doing things.
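+
+    A typical test usage might look like the following sketch (``MyAction`` is an
+    illustrative action type; the result type is the autogenerated GetResult response):
+
+    >>> result = run_action_with_timeout(
+    >>>     client, MyAction.Goal(), MyAction.Impl.GetResultService.Response, timeout=10.0
+    >>> )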
+
+    :param client: The action client
+    :param goal: The goal to send
+    :param response_type: The expected result type. Note that this is an autogenerated
+        type named {ActionName}_GetResult_Response
+    :param timeout: The number of seconds to wait before raising a TimeoutError
+    :return: The result of the action
+    :raises TimeoutError: If the action takes longer than the timeout to complete
+    :raises RuntimeError: If the server shuts down before the action completes
+    """
+
+    try:
+        handle_future: Future = client.send_goal_async(goal)
+
+        result_future = wait_for_future(
+            handle_future, ClientGoalHandle, timeout=timeout
+        ).get_result_async()
+    except TimeoutError as ex:
+        if not client.server_is_ready():
+            msg = "The action server shut down before the action completed!"
+            raise RuntimeError(msg) from ex
+        raise
+    return wait_for_future(result_future, response_type, timeout=timeout)
+
+
+def split_done_futures(futures: list[Future]) -> tuple[list[Future], list[Future]]:
+    """Given a list of futures, categorize them into 'done' and unfinished futures
+    :param futures: The futures to categorize
+    :returns: (done futures, unfinished futures)
+    """
+    finished: list[Future] = []
+    unfinished: list[Future] = []
+    for future in futures:
+        category = finished if future.done() else unfinished
+        category.append(future)
+    return finished, unfinished
+
+
+__all__ = [
+    "run_action_with_timeout",
+    "split_done_futures",
+    "wait_for_future",
+    "wait_for_futures",
+    "wait_for_preemptible_future",
+    "wait_for_send_goal",
+    "yield_for_future",
+    "yield_for_futures",
+]
diff --git a/pkgs/node_helpers/node_helpers/futures/exceptions.py b/pkgs/node_helpers/node_helpers/futures/exceptions.py
new file mode 100644
index 0000000..f135161
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/futures/exceptions.py
@@ -0,0 +1,39 @@
+import logging
+import traceback
+from types import TracebackType
+
+
+class ExceptionCollector:
+    """A helper class for collecting exceptions in cases where multiple operations must
+    finish before an exception can be re-raised.
+
+    In the future, a group exception may be raised using Python 3.11's ExceptionGroup.
+    Currently, maybe_raise() raises the first collected exception.
+    """
+
+    def __init__(self) -> None:
+        self.exceptions: list[tuple[BaseException, TracebackType | None]] = []
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(
+        self,
+        _exc_type: type[BaseException] | None,
+        _exc_val: BaseException | None,
+        _exc_tb: TracebackType | None,
+    ) -> bool:
+        if _exc_val:
+            self.exceptions.append((_exc_val, _exc_tb))
+        return True  # Suppress the exception
+
+    def had_exceptions(self) -> bool:
+        return bool(self.exceptions)
+
+    def maybe_raise(self, log: bool = True) -> None:
+        if self.exceptions:
+            if log:
+                for exc, tb in self.exceptions:
+                    logging.error(f"Exception occurred: {exc}")
+                    logging.error("Traceback:\n" + "".join(traceback.format_tb(tb)))
+            raise self.exceptions[0][0]
diff --git a/pkgs/node_helpers/node_helpers/interaction/README.md b/pkgs/node_helpers/node_helpers/interaction/README.md
new file mode 100644
index 0000000..1d2f48b
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/interaction/README.md
@@ -0,0 +1,23 @@
+# node_helpers.interaction
+
+This module provides a framework for creating interactive menu systems within ROS 2 applications. It is designed to simplify the process of generating dynamic menus and prompts that allow users to make selections, control actions, and provide input during runtime.
+
+## Module Components
+
+### Base Classes
+
+### BaseMenu
+
+A foundational class for creating interactive menus. It provides mechanisms to display options, capture user input, and link selections to specific callbacks.
It also supports running cancelable actions, making it suitable for scenarios where operations need to be interrupted based on user decisions. + +### BasePrompter +An abstract base class that handles the logic of displaying prompts and waiting for user input. It defines methods for connecting to a ROS system, interpreting user messages, and managing subscriptions or services for input. + +### Implementations + +#### DashboardMenu +Extends the functionality of the base menu by integrating with a web-based dashboard. It uses ROS topics and services to display menus and capture responses, making it ideal for remote user interfaces. + +#### DashboardPrompter +Implements the prompter functionality for the web UI. It allows users to make selections through a dashboard and manages the lifecycle of prompts and responses. \ No newline at end of file diff --git a/pkgs/node_helpers/node_helpers/interaction/__init__.py b/pkgs/node_helpers/node_helpers/interaction/__init__.py new file mode 100644 index 0000000..e637d0d --- /dev/null +++ b/pkgs/node_helpers/node_helpers/interaction/__init__.py @@ -0,0 +1,8 @@ +from .menus import ( + BaseMenu, + BasicPromptMetadata, + DashboardMenu, + PromptMetadata, + TeleopPromptMetadata, +) +from .prompting import * # noqa: F403 diff --git a/pkgs/node_helpers/node_helpers/interaction/menus/__init__.py b/pkgs/node_helpers/node_helpers/interaction/menus/__init__.py new file mode 100644 index 0000000..0ea1948 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/interaction/menus/__init__.py @@ -0,0 +1,3 @@ +from .base_menu import BaseMenu +from .dashboard import DashboardMenu +from .prompt_metadata import BasicPromptMetadata, PromptMetadata, TeleopPromptMetadata diff --git a/pkgs/node_helpers/node_helpers/interaction/menus/base_menu.py b/pkgs/node_helpers/node_helpers/interaction/menus/base_menu.py new file mode 100644 index 0000000..1502624 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/interaction/menus/base_menu.py @@ -0,0 +1,156 @@ +from collections.abc import Callable +from concurrent.futures import CancelledError, Future +from functools import partial +from queue import Queue +from threading import Event +from typing import Any, Generic, TypeVar + +from action_msgs.msg import GoalStatus +from node_helpers_msgs.msg import PromptOption, UserPrompt +from rclpy.action.client import ClientGoalHandle + +from node_helpers.robust_rpc import RobustActionClient + +from ..prompting import BasePrompter +from .prompt_metadata import BasicPromptMetadata, PromptMetadata + +MESSAGE = TypeVar("MESSAGE") +"""What ros message to look for""" + +PROMPT_PUBLISHER = Callable[[UserPrompt], None] +MENU_ITEMS = tuple[PromptOption, Callable[[], None]] +DEFAULT_CANCEL_OPTION = PromptOption(name="Cancel", description="❌") + + +class BaseMenu(Generic[MESSAGE]): + """This class can take in any kind of Prompter and turn it into a menu""" + + def __init__( + self, prompter: BasePrompter[MESSAGE], publish_prompt: PROMPT_PUBLISHER + ): + """ + :param prompter: The type of prompter to use for this menu + :param publish_prompt: A callable to publish a prompt to the user, so they know + what buttons can be pressed + """ + + self.publish_prompt = publish_prompt + self._prompter = prompter + + def display_menu_async( + self, + *menu_items: MENU_ITEMS, + help_message: str = "", + metadata: PromptMetadata | None = None, + ) -> "Future[Callable[[], None]]": + metadata = metadata or BasicPromptMetadata() + + prompt_options = [o for o, _ in menu_items] + + # Show the options to the user + 
self.publish_prompt( + UserPrompt( + options=prompt_options, + help=help_message, + type=metadata.type(), + metadata=metadata.model_dump_json(), + ) + ) + + choice_future = self._prompter.choose_async(choices=menu_items) + + # Add a callback that will call the chosen callable when the future finishes + def on_done(future: "Future[Callable[[], None]]") -> None: + try: + future.result()() + except CancelledError: + # If the future was cancelled, no need to call a callback + pass + + choice_future.add_done_callback(on_done) + return choice_future + + def display_menu( + self, + *menu_items: MENU_ITEMS, + help_message: str = "", + metadata: PromptMetadata | None = None, + ) -> None: + self.display_menu_async( + *menu_items, help_message=help_message, metadata=metadata + ).result() + + def run_user_cancellable_action( + self, + action: RobustActionClient, + goal: Any, + help_message: str = "", + menu_items: tuple[MENU_ITEMS, ...] | None = None, + metadata: PromptMetadata | None = None, + ) -> tuple[Any, bool]: + """Run an action while also checking for user input to see if the user wants + the action cancelled. If no menu items are specified, a single option will be + created to allow cancelling the action. if menu items are specified, then + any of those menu items will cancel the action, but will also call their + specified callbacks. + + :param action: The action to run + :param goal: The goal to give the action + :param help_message: The help message to display while the action runs + :param menu_items: An optional list of menu items. All menu items when selected + will cancel the current action, but will also call their specified callback. + :param metadata: Specifies additional data for the prompt, enabling + special features + :returns: A tuple of (action result, bool "cancelled") + """ + + if menu_items is None: + # If no menu items are set, default to having a simple "Cancel action" item. 
+ menu_items = ((DEFAULT_CANCEL_OPTION, lambda: None),) + + # Create an event that will be set either by the prompter or the action + finished_or_canceled = Event() + + def wrap_callback(callback: Callable[[], None]) -> Callable[[], None]: + """Wrap a callback to also set the finished_or_cancelled event""" + + def wrapper() -> None: + finished_or_canceled.set() + return callback() + + return wrapper + + handle: ClientGoalHandle + with action.send_goal_as_context(goal=goal) as handle: + # Augment each menu item to also set the 'finished_or_cancelled' event + menu_items_that_also_cancel_action = [ + (text, wrap_callback(callback)) for text, callback in menu_items + ] + + menu_future = self.display_menu_async( + *menu_items_that_also_cancel_action, + help_message=help_message, + metadata=metadata, + ) + handle.get_result_async().add_done_callback( + lambda *args: finished_or_canceled.set() + ) + + # Now wait for the user to either cancel the request, or it to finish + finished_or_canceled.wait() + self.publish_prompt(UserPrompt(help="Finishing action...")) + + result = handle.get_result() + canceled = menu_future.done() or result.status == GoalStatus.STATUS_CANCELED + + menu_future.cancel() + return result, canceled + + def ask_yes_no(self, question: str, yes: str = "Yes", no: str = "No") -> bool: + response: "Queue[bool]" = Queue() + self.display_menu( + (PromptOption(name=yes, description="👍"), partial(response.put, True)), + (PromptOption(name=no, description="👎"), partial(response.put, False)), + help_message=question, + ) + return response.get_nowait() diff --git a/pkgs/node_helpers/node_helpers/interaction/menus/dashboard.py b/pkgs/node_helpers/node_helpers/interaction/menus/dashboard.py new file mode 100644 index 0000000..a7958fb --- /dev/null +++ b/pkgs/node_helpers/node_helpers/interaction/menus/dashboard.py @@ -0,0 +1,54 @@ +from collections.abc import Generator +from contextlib import contextmanager + +from node_helpers_msgs.msg import PromptOption, UserPrompt + +from node_helpers.nodes import HelpfulNode +from node_helpers.topics import LatchingPublisher + +from ..prompting import DashboardPrompter +from .base_menu import BaseMenu + + +class DashboardMenu(BaseMenu[PromptOption]): + """A menu controlled by buttons on a web UI""" + + DEFAULT_PROMPT_TOPIC = "/dashboard/prompts" + DEFAULT_OPTION_SERVICE = "/dashboard/choose_option" + + def __init__(self, node: HelpfulNode): + self._latching_prompt_publisher = LatchingPublisher( + node, UserPrompt, topic=self.DEFAULT_PROMPT_TOPIC + ) + self._dashboard_prompter = DashboardPrompter( + node=node, srv_name=self.DEFAULT_OPTION_SERVICE + ) + + super().__init__( + prompter=self._dashboard_prompter, + publish_prompt=self._latching_prompt_publisher, + ) + + @contextmanager + def connected(self) -> Generator["DashboardMenu", None, None]: + self.connect() + try: + yield self + finally: + self.disconnect() + + @contextmanager + def disconnected(self) -> Generator[None, None, None]: + self.disconnect() + try: + yield + finally: + self.connect() + + def connect(self) -> None: + self._dashboard_prompter.connect() + + def disconnect(self) -> None: + """Relinquish control of the prompt topic, so another node can use it""" + self._latching_prompt_publisher.clear_msg_state() + self._dashboard_prompter.disconnect() diff --git a/pkgs/node_helpers/node_helpers/interaction/menus/prompt_metadata.py b/pkgs/node_helpers/node_helpers/interaction/menus/prompt_metadata.py new file mode 100644 index 0000000..2e08d94 --- /dev/null +++ 
b/pkgs/node_helpers/node_helpers/interaction/menus/prompt_metadata.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from typing import cast + +from node_helpers_msgs.msg import UserPrompt +from pydantic import BaseModel + + +class PromptMetadata(ABC, BaseModel): + """Defines the metadata field's schema for a UserPrompt type. See the + message definition for UserPrompt for more information. + """ + + @staticmethod + @abstractmethod + def type() -> int: + """The prompt type that this metadata is used with""" + + +class BasicPromptMetadata(PromptMetadata): + @staticmethod + def type() -> int: + return cast(int, UserPrompt.PROMPT_BASIC) + + +class TeleopPromptMetadata(PromptMetadata): + namespaces: list[str] + + @staticmethod + def type() -> int: + return cast(int, UserPrompt.PROMPT_TELEOP) diff --git a/pkgs/node_helpers/node_helpers/interaction/prompting/__init__.py b/pkgs/node_helpers/node_helpers/interaction/prompting/__init__.py new file mode 100644 index 0000000..ee85956 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/interaction/prompting/__init__.py @@ -0,0 +1,4 @@ +from .base_prompter import BasePrompter +from .dashboard_prompter import DashboardPrompter + +__all__ = ["BasePrompter", "DashboardPrompter"] diff --git a/pkgs/node_helpers/node_helpers/interaction/prompting/base_prompter.py b/pkgs/node_helpers/node_helpers/interaction/prompting/base_prompter.py new file mode 100644 index 0000000..a9365be --- /dev/null +++ b/pkgs/node_helpers/node_helpers/interaction/prompting/base_prompter.py @@ -0,0 +1,54 @@ +import logging +from abc import abstractmethod +from concurrent.futures import Future +from typing import Generic, TypeVar + +from rclpy.node import Node + +MESSAGE = TypeVar("MESSAGE") +"""A ROS message to look for, when selecting prompt options""" + +Choice = TypeVar("Choice", bound=object) + + +class BasePrompter(Generic[MESSAGE]): + def __init__(self, node: Node): + # Insert a callback that stops waiting on a breakpoint during closing time + node.context.on_shutdown(self.disconnect) + + @abstractmethod + def _add_request_for_message( + self, choices: tuple[tuple[MESSAGE, Choice], ...] + ) -> "Future[Choice]": + """Return a future that will result in the choice a human chose""" + + @abstractmethod + def describe_message(self, message: MESSAGE) -> str: + """Takes a message and returns a human readable description""" + + @abstractmethod + def connect(self) -> None: + """Set up any ROS subscriptions, services, etc""" + + @abstractmethod + def disconnect(self) -> None: + """Clear the subscription and end the current future (if there is any) + + Expectations: + 1) Set any ongoing future(s) to a RuntimeError + 2) Destroy any ROS subscriptions, services, etc + 3) Ensure no new futures can be created + """ + + def choose_async( + self, choices: tuple[tuple[MESSAGE, Choice], ...] 
+ ) -> "Future[Choice]": + msg = "\n\t".join(f"{self.describe_message(j)}: {c}" for j, c in choices) + logging.warning(f"BREAKPOINT- Choose an option:\n\t{msg}") + return self._add_request_for_message(choices) + + def choose( + self, choices: tuple[tuple[MESSAGE, Choice], ...], timeout: float | None = None + ) -> Choice: + """Prompt the user to pick an option""" + return self.choose_async(choices).result(timeout=timeout) diff --git a/pkgs/node_helpers/node_helpers/interaction/prompting/dashboard_prompter.py b/pkgs/node_helpers/node_helpers/interaction/prompting/dashboard_prompter.py new file mode 100644 index 0000000..f541ebc --- /dev/null +++ b/pkgs/node_helpers/node_helpers/interaction/prompting/dashboard_prompter.py @@ -0,0 +1,105 @@ +import logging +from concurrent.futures import Future, InvalidStateError +from dataclasses import dataclass, field +from threading import RLock +from typing import Any + +from node_helpers_msgs.msg import PromptOption +from node_helpers_msgs.srv import ChoosePromptOption +from rclpy.callback_groups import MutuallyExclusiveCallbackGroup +from rclpy.service import Service + +from node_helpers.nodes import HelpfulNode + +from .base_prompter import BasePrompter, Choice + + +class InvalidPromptError(Exception): + pass + + +@dataclass +class _Request: + choices: dict[str, object] + """Store the PromptOption.name -> Choice relationship""" + + future: Future[Any] = field(default_factory=Future) + + +class DashboardPrompter(BasePrompter[PromptOption]): + def __init__(self, node: HelpfulNode, srv_name: str): + super().__init__(node) + self._node = node + self._service: Service | None = None + self._srv_name: str = srv_name + + self._request_lock = RLock() + self._ongoing_request: _Request | None = None + + def _on_user_chooses_option( + self, request: ChoosePromptOption.Request, response: ChoosePromptOption.Response + ) -> ChoosePromptOption.Response: + """Called by a service, this means a user has chosen some menu option""" + logging.info( + f"User has selected the option {self.describe_message(request.option)}!" + ) + + option_name = request.option.name + with self._request_lock: + if self._ongoing_request is None: + raise InvalidPromptError( + f"There's no ongoing prompt right now! Received: {option_name}" + ) + + try: + choice = self._ongoing_request.choices[option_name] + except KeyError as ex: + raise InvalidPromptError( + f"The option '{option_name}' is not currently being prompted for!" + ) from ex + else: + self._ongoing_request.future.set_result(choice) + self._ongoing_request = None + return response + + def _add_request_for_message( + self, choices: tuple[tuple[PromptOption, Choice], ...] 
+    ) -> Future[Choice]:
+        with self._request_lock:
+            if self._ongoing_request is not None:
+                # Mark the current future as cancelled
+                self._ongoing_request.future.cancel()
+            future: Future[Choice] = Future()
+            self._ongoing_request = _Request(
+                choices={o.name: c for o, c in choices}, future=future
+            )
+            return future
+
+    def describe_message(self, message: PromptOption) -> str:
+        """Detailed descriptions don't make much sense for PromptOptions"""
+        return f"Option(name='{message.name}')"
+
+    def connect(self) -> None:
+        """Create the service that listens for the user's menu choices"""
+        if self._service is None:
+            self._service = self._node.create_robust_service(
+                ChoosePromptOption,
+                self._srv_name,
+                self._on_user_chooses_option,
+                callback_group=MutuallyExclusiveCallbackGroup(),
+            )
+
+    def disconnect(self) -> None:
+        """Destroy the service and error out any ongoing requests"""
+        with self._request_lock:
+            if self._service:
+                self._node.destroy_service(self._service)
+                self._service = None
+
+            if self._ongoing_request:
+                msg = "A system exit was requested while waiting for user input!"
+                try:
+                    self._ongoing_request.future.set_exception(RuntimeError(msg))
+                except InvalidStateError:
+                    pass
+                self._ongoing_request = None
diff --git a/pkgs/node_helpers/node_helpers/launching/README.md b/pkgs/node_helpers/node_helpers/launching/README.md
new file mode 100644
index 0000000..6c9e8bc
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/launching/README.md
@@ -0,0 +1,7 @@
+# node_helpers.launching
+
+The `node_helpers.launching` module provides utilities to simplify ROS launch file management, including dynamic node swapping and file validation.
+
+It supports the swapping of "real" and "mock" nodes during launch via the `SwappableNode` class and `apply_node_swaps` function. Additionally, it includes file validation utilities (`required_file` and `required_directory`) and URDF manipulation functions for namespace application and name validation.
+
+Full documentation can be found under [docs/](../../../../docs/launching.rst).
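+
+As a rough sketch of the swapping API (the package, executable, and node names
+here are illustrative, not part of the library):
+
+```python
+from node_helpers.launching import SwapConfiguration, SwappableNode, apply_node_swaps
+
+config = {
+    "camera": SwapConfiguration(mock="mock_camera", real="real_camera", enable_mock=True)
+}
+nodes = [
+    SwappableNode(
+        namespace="camera", name="real_camera",
+        package="my_robot", executable="camera_node",
+    ),
+    SwappableNode(
+        namespace="camera", name="mock_camera",
+        package="my_robot", executable="mock_camera_node",
+    ),
+]
+# With enable_mock=True, only the mock node is kept in the launch description
+entities = apply_node_swaps(config, nodes)
+```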
diff --git a/pkgs/node_helpers/node_helpers/launching/__init__.py b/pkgs/node_helpers/node_helpers/launching/__init__.py
new file mode 100644
index 0000000..49bd0dc
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/launching/__init__.py
@@ -0,0 +1,8 @@
+from .files import required_directory, required_file
+from .swappable_nodes import (
+    InvalidSwapConfiguration,
+    SwapConfiguration,
+    SwappableNode,
+    apply_node_swaps,
+)
+from .urdf import fix_urdf_paths, prepend_namespace
diff --git a/pkgs/node_helpers/node_helpers/launching/files.py b/pkgs/node_helpers/node_helpers/launching/files.py
new file mode 100644
index 0000000..c71289d
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/launching/files.py
@@ -0,0 +1,15 @@
+from pathlib import Path
+
+
+def required_directory(*elements: str | Path) -> Path:
+    path = Path(*elements)
+    if not path.is_dir():
+        raise FileNotFoundError(f"Directory {path} not found")
+    return path
+
+
+def required_file(*elements: str | Path) -> Path:
+    path = Path(*elements)
+    if not path.is_file():
+        raise FileNotFoundError(f"File {path} not found")
+    return path
diff --git a/pkgs/node_helpers/node_helpers/launching/swappable_nodes.py b/pkgs/node_helpers/node_helpers/launching/swappable_nodes.py
new file mode 100644
index 0000000..5c643f6
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/launching/swappable_nodes.py
@@ -0,0 +1,102 @@
+from collections import defaultdict
+from collections.abc import Iterable
+from typing import Any
+
+from launch import LaunchDescriptionEntity
+from launch_ros.actions import Node
+from pydantic import BaseModel
+
+
+class InvalidSwapConfiguration(Exception):
+    pass
+
+
+class SwapConfiguration(BaseModel):
+    mock: str
+    real: str
+    enable_mock: bool
+
+
+class SwappableNode(Node):
+    """Unfortunately, there's no way to recover the 'name' or 'namespace' of a launch
+    Node by looking at Node.name or Node.namespace, so we subclass Node and keep
+    track of those attributes in __init__.
+
+    If this node exits early during operation, the entire system will shut down.
+    """
+
+    def __init__(self, *, namespace: str, name: str, **kwargs: Any):
+        super().__init__(namespace=namespace, name=name, **kwargs)
+        self.swap_namespace = namespace
+        self.swap_name = name
+
+    def should_use(self, configuration: SwapConfiguration) -> bool:
+        if configuration.real == self.swap_name:
+            if configuration.enable_mock:
+                return False
+            return True
+        elif configuration.mock == self.swap_name:
+            if configuration.enable_mock:
+                return True
+            return False
+        else:
+            raise InvalidSwapConfiguration(
+                "This function should only be fed a relevant SwapConfiguration. "
+                f"Received {configuration=} for {self.swap_name=}"
+            )
+
+
+def apply_node_swaps(
+    configuration: dict[str, SwapConfiguration],
+    launch_description: Iterable[LaunchDescriptionEntity],
+) -> list[LaunchDescriptionEntity]:
+    """Swap between "real" and "mock" nodes based on configuration. This is most
+    useful when there's a 1:1 correspondence between nodes and their respective
+    mocks.
+
+    :param configuration: A dictionary of {'namespace': SwapConfiguration}
+    :param launch_description: Output from a normal generate_launch_description
+    :raises InvalidSwapConfiguration: If there's any incorrect configuration in the
+        swap configuration, or if nodes that should have existed are missing.
+    :return: A list of filtered launch entities, with swaps applied
+    """
+    filtered_launch_description = []
+    seen_referenced_nodes: dict[str, list[str]] = defaultdict(list)
+    """Keep track of all referenced namespace:node_name pairs in the configuration.
+    If any aren't fully described after looking through the launch description, raise
+    an exception.
+    """
+
+    for element in launch_description:
+        if not isinstance(element, SwappableNode):
+            filtered_launch_description.append(element)
+            continue
+
+        # Track the node's name
+        seen_referenced_nodes[element.swap_namespace].append(element.swap_name)
+
+        # Validate the configuration contains this namespace
+        if element.swap_namespace not in configuration:
+            raise InvalidSwapConfiguration(
+                f"Expected swap configuration for "
+                f"{element.swap_namespace}.{element.swap_name} but was given none! "
+                f"{configuration=}"
+            )
+
+        # Decide if the node should be kept
+        swap_configuration = configuration[element.swap_namespace]
+        if element.should_use(swap_configuration):
+            filtered_launch_description.append(element)
+
+    # Validate all nodes that were referenced in configuration were seen
+    if len(seen_referenced_nodes.keys()) != len(configuration.keys()):
+        raise InvalidSwapConfiguration(
+            "Not all namespaces referenced in the swap configuration were seen in the "
+            "launch file description!"
+        )
+    # Validate all seen namespaces contain two nodes that would have been swapped
+    for seen_swappable_nodes in seen_referenced_nodes.values():
+        if len(seen_swappable_nodes) != 2:
+            raise InvalidSwapConfiguration("Swappable nodes should come in pairs!")
+
+    return filtered_launch_description
diff --git a/pkgs/node_helpers/node_helpers/markers/__init__.py b/pkgs/node_helpers/node_helpers/markers/__init__.py
new file mode 100644
index 0000000..876342f
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/markers/__init__.py
@@ -0,0 +1,5 @@
+from .arrow import create_point_to_point_arrow_marker
+from .colors import color_from_seed, color_msg_from_seed
+from .interactive_marker import InteractiveVolumeMarker
+from .marker_arrays import ascending_id_marker_array
+from .text import create_floating_text
diff --git a/pkgs/node_helpers/node_helpers/markers/arrow.py b/pkgs/node_helpers/node_helpers/markers/arrow.py
new file mode 100644
index 0000000..093a9ce
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/markers/arrow.py
@@ -0,0 +1,40 @@
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+from geometry_msgs.msg import Point, Vector3
+from visualization_msgs.msg import Marker
+
+
+def create_point_to_point_arrow_marker(
+    shaft_diameter: float,
+    head_diameter: float,
+    head_length: float,
+    base_point: tuple[float, float, float] | npt.NDArray[np.float64] | Vector3 | Point,
+    head_point: tuple[float, float, float] | npt.NDArray[np.float64] | Vector3 | Point,
+    **marker_kwargs: Any,
+) -> Marker:
+    """Create an arrow marker pointing from a base point to a head point.
+ + :param shaft_diameter: The diameter of the arrow shaft + :param head_diameter: The diameter of the arrow head + :param head_length: The length of the arrow head + :param base_point: The base point of the arrow + :param head_point: The head point of the arrow + :param marker_kwargs: The arguments to pass on to the Marker + + :return: The arrow marker + """ + if isinstance(base_point, Vector3 | Point): + base_point = (base_point.x, base_point.y, base_point.z) + if isinstance(head_point, Vector3 | Point): + head_point = (head_point.x, head_point.y, head_point.z) + + return Marker( + type=Marker.ARROW, + # x: Shaft diameter, Y: head diameter, Z: head length, as per: + # http://wiki.ros.org/rviz/DisplayTypes/Marker#Arrow_.28ARROW.3D0.29 + scale=Vector3(x=shaft_diameter, y=head_diameter, z=head_length), + points=[Point(x=p[0], y=p[1], z=p[2]) for p in [base_point, head_point]], + **marker_kwargs, + ) diff --git a/pkgs/node_helpers/node_helpers/markers/colors.py b/pkgs/node_helpers/node_helpers/markers/colors.py new file mode 100644 index 0000000..c926483 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/markers/colors.py @@ -0,0 +1,15 @@ +import random + +from std_msgs.msg import ColorRGBA + + +def color_from_seed(seed: str | float) -> tuple[float, float, float]: + """Consistently generate a random color from a given seed, as 0-1 floats.""" + rand = random.Random(seed) + color = (rand.random(), rand.random(), rand.random()) + return color + + +def color_msg_from_seed(seed: str | float, alpha: float = 1.0) -> ColorRGBA: + color = color_from_seed(seed) + return ColorRGBA(r=color[0], g=color[1], b=color[2], a=alpha) diff --git a/pkgs/node_helpers/node_helpers/markers/interactive_marker.py b/pkgs/node_helpers/node_helpers/markers/interactive_marker.py new file mode 100644 index 0000000..c14fbf0 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/markers/interactive_marker.py @@ -0,0 +1,101 @@ +from geometry_msgs.msg import Point, Pose, Quaternion, Vector3 +from std_msgs.msg import ColorRGBA, Header +from visualization_msgs.msg import InteractiveMarker, InteractiveMarkerControl, Marker + + +class InteractiveVolumeMarker: + _axis = [ + ( + Quaternion(w=1.0, x=1.0, y=0.0, z=0.0), + "rotate_x", + InteractiveMarkerControl.ROTATE_AXIS, + ), + ( + Quaternion(w=1.0, x=0.0, y=1.0, z=0.0), + "rotate_z", + InteractiveMarkerControl.ROTATE_AXIS, + ), + ( + Quaternion(w=1.0, x=0.0, y=0.0, z=1.0), + "rotate_y", + InteractiveMarkerControl.ROTATE_AXIS, + ), + ( + Quaternion(w=1.0, x=1.0, y=0.0, z=0.0), + "move_x", + InteractiveMarkerControl.MOVE_AXIS, + ), + ( + Quaternion(w=1.0, x=0.0, y=1.0, z=0.0), + "move_z", + InteractiveMarkerControl.MOVE_AXIS, + ), + ( + Quaternion(w=1.0, x=0.0, y=0.0, z=1.0), + "move_y", + InteractiveMarkerControl.MOVE_AXIS, + ), + ] + + def __init__( + self, + name: str, + frame_id: str, + scale: list[float], + description: str = "", + position: Point = None, + orientation: Quaternion = None, + fixed: bool = False, + show_6dof: bool = True, + ): + """ + :param name: The visual name of the marker + :param frame_id: The frame ID the marker will move inside + :param scale: The [x, y, z] scaling of the marker + :param description: The description that will show up in rviz + :param position: The initial position of the marker + :param orientation: The initial orientation of the marker. 
+        :param fixed: If True, control orientations stay fixed to the reference frame
+        :param show_6dof: Whether to display the 6dof controls
+        """
+        initial_pose = Pose(
+            position=position or Point(), orientation=orientation or Quaternion()
+        )
+
+        self.box = Marker(
+            scale=Vector3(x=scale[0], y=scale[1], z=scale[2]),
+            color=ColorRGBA(r=0.5, g=0.5, b=0.5, a=0.5),
+            header=Header(frame_id=frame_id),
+            pose=initial_pose,
+            type=Marker.CUBE,
+        )
+
+        self.box_control = InteractiveMarkerControl(
+            always_visible=True, markers=[self.box]
+        )
+
+        self.interactive_marker = InteractiveMarker(
+            pose=initial_pose,
+            # Make the controls encompass the whole of the box
+            scale=max(scale) * 2.25,
+            header=self.box.header,
+            name=name,
+            description=description,
+            controls=[self.box_control],
+        )
+
+        # Generate 6dof controls
+        if show_6dof:
+            for (
+                orientation,
+                control_name,
+                axis_interaction,
+            ) in self._axis:
+                control = InteractiveMarkerControl(
+                    orientation=orientation,
+                    name=control_name,
+                    interaction_mode=axis_interaction,
+                )
+                if fixed:
+                    control.orientation_mode = InteractiveMarkerControl.FIXED
+                self.interactive_marker.controls.append(control)
diff --git a/pkgs/node_helpers/node_helpers/markers/marker_arrays.py b/pkgs/node_helpers/node_helpers/markers/marker_arrays.py
new file mode 100644
index 0000000..bfac7c5
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/markers/marker_arrays.py
@@ -0,0 +1,26 @@
+from visualization_msgs.msg import Marker, MarkerArray
+
+
+def ascending_id_marker_array(
+    markers: list[Marker],
+    start_id: int = 0,
+    delete_existing: bool = True,
+    marker_namespace: str = "",
+) -> MarkerArray:
+    """Return the same list of markers modified in-place so that each marker has
+    ascending IDs.
+    :param markers: The markers to modify in-place
+    :param start_id: The ID to begin at
+    :param delete_existing: If True, all markers for this topic will be cleared before
+        the new markers are added
+    :param marker_namespace: The namespace to create markers under
+    :returns: A MarkerArray with markers that have ascending IDs
+    """
+
+    if delete_existing:
+        markers.insert(0, Marker(action=Marker.DELETEALL, ns=marker_namespace))
+
+    for i, marker in enumerate(markers, start=start_id):
+        marker.id = i
+        marker.ns = marker_namespace
+    return MarkerArray(markers=markers)
diff --git a/pkgs/node_helpers/node_helpers/markers/text.py b/pkgs/node_helpers/node_helpers/markers/text.py
new file mode 100644
index 0000000..aa5e137
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/markers/text.py
@@ -0,0 +1,18 @@
+from typing import Any
+
+from geometry_msgs.msg import Pose, Vector3
+from visualization_msgs.msg import Marker
+
+
+def create_floating_text(
+    text: str, text_height: float, pose: Pose, **marker_kwargs: Any
+) -> Marker:
+    return Marker(
+        type=Marker.TEXT_VIEW_FACING,
+        # Spaces are replaced with underscores because of a weird rviz bug where
+        # spaces make for absolutely gargantuan spaces between words
+        text=text.replace(" ", "_"),
+        scale=Vector3(z=text_height),
+        pose=pose,
+        **marker_kwargs,
+    )
diff --git a/pkgs/node_helpers/node_helpers/nodes/__init__.py b/pkgs/node_helpers/node_helpers/nodes/__init__.py
index 6492a19..bc073e5 100644
--- a/pkgs/node_helpers/node_helpers/nodes/__init__.py
+++ b/pkgs/node_helpers/node_helpers/nodes/__init__.py
@@ -1 +1,17 @@
-from .node_helpers_node import ExampleNode
\ No newline at end of file
+from .helpful_node import HelpfulNode
+from .interactive_transform_publisher import (
+    InteractiveTransformPublisher,
+    TransformDescription,
+    TransformModel,
+    TransformsFile,
+)
+from .sound_player import SoundPlayer
+
+__all__ = [
+    "HelpfulNode",
+    "InteractiveTransformPublisher",
+    "SoundPlayer",
+    "TransformDescription",
+    "TransformModel",
+    "TransformsFile",
+]
diff --git a/pkgs/node_helpers/node_helpers/nodes/helpful_node.py b/pkgs/node_helpers/node_helpers/nodes/helpful_node.py
new file mode 100644
index 0000000..442dc5b
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/nodes/helpful_node.py
@@ -0,0 +1,17 @@
+from rclpy.node import Node
+
+from node_helpers.destruction import DestroyCallbacksMixin
+from node_helpers.parameters import ParameterMixin
+from node_helpers.robust_rpc import RobustRPCMixin
+from node_helpers.timing import SingleShotMixin, TimerWithWarningsMixin
+
+
+class HelpfulNode(
+    ParameterMixin,
+    RobustRPCMixin,
+    DestroyCallbacksMixin,
+    SingleShotMixin,
+    TimerWithWarningsMixin,
+    Node,
+):
+    """This node class combines all helper mixins within node_helpers into one node"""
diff --git a/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/README.md b/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/README.md
new file mode 100644
index 0000000..57d4f86
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/README.md
@@ -0,0 +1,14 @@
+# interactive_transform_publisher
+
+The `interactive_transform_publisher` is a ROS2 node for interactively manipulating and publishing static transforms using RViz. It simplifies calibration and configuration of transforms for URDFs or static TF trees.
+
+### Features
+- **Interactive Editing**: Manipulate transforms visually in RViz using interactive markers.
+- **Persistence**: Save and load transforms from a configuration file for reuse across sessions.
+- **API for Adjustments**: Modify and create transforms programmatically using provided topics and client utilities.
+
+### Key Topics
+- `/tf_static_updates`: Update existing transforms.
+- `/tf_static_create`: Create new transforms dynamically.
+
+For full documentation, see [interactive_transform_publisher.rst](../../../../docs/interactive_transform_publisher.rst).
diff --git a/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/__init__.py b/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/__init__.py
new file mode 100644
index 0000000..ca33086
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/__init__.py
@@ -0,0 +1,302 @@
+import logging
+from functools import partial
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+from builtin_interfaces.msg import Time
+from geometry_msgs.msg import Point, Pose, Quaternion, TransformStamped
+from interactive_markers import InteractiveMarkerServer
+from visualization_msgs.msg import InteractiveMarkerFeedback
+
+from node_helpers.markers import InteractiveVolumeMarker
+from node_helpers.nodes.helpful_node import HelpfulNode
+from node_helpers.parameters import param_path
+from node_helpers.qos import qos_profile_transient_reliable_keep_all
+from node_helpers.ros2_numpy import msgify, numpify
+from node_helpers.spinning import create_spin_function
+from node_helpers.tf import ConstantStaticTransformBroadcaster
+
+from .exceptions import DuplicateTransformError, MultipleParentsError
+from .schemas import TransformDescription, TransformModel, TransformsFile
+
+
+class InteractiveTransformPublisher(HelpfulNode):
+    """
+    An interactive transform publisher node for ROS 2 that provides a way to
+    interactively manipulate and publish static transforms using RViz.
+ + This node allows users to easily calibrate and glue URDFs together by exposing + interactive markers in RViz for manipulating transforms. It reads and writes + transforms to a specified file, allowing easy persistence of transforms across + sessions. + + The node exposes the following topics: + * `/tf_static_updates`: A topic to listen for external updates to transforms. + If a TF is published here that doesn't yet exist, it will be ignored. + * `/tf_static_create`: A topic for creating new TFs, if they didn't already exist. + If a TF is published here that already existed (e.g. from a file or ROS param + config), no change will be made. + + To use this node, make sure to have an appropriate configuration that specifies + the required parameters such as static_transforms_file, transforms, + tf_publish_frequency, and scale_factor. + """ + + def __init__(self, **kwargs: Any): + super().__init__("interactive_transform_publisher", **kwargs) + # Set up parameters + self.transforms_path = param_path( + self.declare_and_get_parameter( + "static_transforms_file", type_=str, required=True, description="" + ) + ) + initial_transforms: list[str] = self.declare_and_get_parameter( + "transforms", + type_=list[str], + required=True, + description="A list of strings in the format of " + "'parent_tf_name:child_tf_name', which will turn into zeroed out " + "transforms that can then be interacted with.", + ) + self.publish_seconds = 1 / self.declare_and_get_parameter( + "tf_publish_frequency", + type_=float, + required=True, + description="The frequency (HZ) to publish to TF static", + ) + self.scale_factor = self.declare_and_get_parameter( + "scale_factor", + type_=float, + default_value=1.0, + description="How much to scale movements by", + ) + + # Set up the interactive marker server + self.interaction_server = InteractiveMarkerServer( + node=self, + namespace=self.get_namespace(), + ) + + # Set up a receiver for TF changes + self.create_subscription( + TransformStamped, + "tf_static_updates", + callback=self._on_update_transform, + qos_profile=qos_profile_transient_reliable_keep_all, + ) + + # Set up a receiver for TF creation (for defining TFs that don't exist in ROS + # params or in the state file) + self.create_subscription( + TransformStamped, + "tf_static_create", + callback=self._on_create_transform, + qos_profile=qos_profile_transient_reliable_keep_all, + ) + + # Register transforms from the state file into self.transforms and the server + self.transforms: list[TransformDescription] = [] + self._load_state_file(initial_transforms, self.transforms_path) + + def _load_state_file( + self, allowed_transform_pairs: list[str], transform_path: Path + ) -> None: + """Load transforms from a file into self.transforms and the interactive marker + server. + This function ensures that 'stale' transforms in the file that are not specified + in configuration (and were not created via API) are removed from the file. + + :param allowed_transform_pairs: A list of strings of format 'parent_tf:child_tf' + :param transform_path: A path to a file that follows the pydantic TransformsFile + schema. 
+ """ + + # First, we start by making identity transforms, as specified by ROS params + allowed_transforms: dict[str, TransformModel] = { + key: TransformModel(parent=p, child=c, created_via_api=False) + for key, p, c in ((s, *s.split(":")) for s in allowed_transform_pairs) + } + if transform_path.is_file(): + file_transforms = TransformsFile.model_validate_json( + transform_path.read_text() + ).transforms + + # The following logic loads transforms only if either of these are true: + # 1) They were specified in ROS parameters + # 2) They were in the transforms file AND marked as having been created + # via API + for loaded_transform in file_transforms: + key = f"{loaded_transform.parent}:{loaded_transform.child}" + if key in allowed_transforms: + # Ensure that parameter specified transforms are marked as such + loaded_transform.created_via_api = False + allowed_transforms[key] = loaded_transform + elif loaded_transform.created_via_api: + allowed_transforms[key] = loaded_transform + + for model in allowed_transforms.values(): + self._register_transform(model) + + def _register_transform(self, model: TransformModel) -> TransformDescription: + # First, validate this transform isn't a duplicate that's already been added + existing_marker_names: set[str] = {t.model.marker_name for t in self.transforms} + if model.marker_name in existing_marker_names: + raise DuplicateTransformError( + f"Tried to register {model.marker_name}, which already exists!" + ) + + # Next, validate this transforms child doesn't already have a registered parent + existing_children: set[str] = {t.model.child for t in self.transforms} + if model.child in existing_children: + raise MultipleParentsError( + f"Tried to register {model.child} as a child to {model.parent}, but " + f"it already has a different parent configured in the interactive " + f"transforms configuration! It's possible this resulted from changing " + f"yaml configuration multiple times with different transforms. 
" + f"To fix this You can either edit the configuration file directly, or " + f"delete it and relaunch the stack.\n" + f"Bad configuration file path: {self.transforms_path}" + ) + + # Create a ConstantStaticTransformPublisher, and reference it in self.transforms + transform_description = TransformDescription( + model=model, + broadcaster=ConstantStaticTransformBroadcaster( + self, publish_seconds=self.publish_seconds + ), + ) + self.transforms.append(transform_description) + + # Set the transform for the broadcaster + self.get_logger().info(f"Publishing static transform {model}") + transform_description.broadcaster.set_transform(model.to_msg(stamp=Time())) + + # Add this to the interactive transform server + self._insert_interactive_transform(transform_description) + return transform_description + + def _insert_interactive_transform( + self, transform_description: TransformDescription + ) -> None: + """Insert a transform without applying changes on the server""" + transform = transform_description.model + marker = InteractiveVolumeMarker( + name=transform_description.model.marker_name, + description=transform.child, + scale=[0.05, 0.05, 0.05], + frame_id=transform.parent, + position=msgify(Point, np.array(transform.translation) / self.scale_factor), + orientation=msgify(Quaternion, np.array(transform.rotation)), + ) + self.interaction_server.insert( + marker.interactive_marker, + feedback_callback=partial(self._feedback, transform_description), + ) + self.interaction_server.applyChanges() + + def _feedback( + self, + transform_description: TransformDescription, + feedback: InteractiveMarkerFeedback, + ) -> None: + """Called when a user moves a transform in rviz""" + transform = transform_description.model + + # Get the interpolated translation + position = numpify(feedback.pose.position) + previous_position = transform.translation + pos_delta = position - previous_position + interpolated_transform = ( + (previous_position + pos_delta) * self.scale_factor + ).tolist() + + # Publish a static transform + transform.rotation = tuple(numpify(feedback.pose.orientation).tolist()) + transform.translation = tuple(interpolated_transform) + msg = transform.to_msg(stamp=self.get_clock().now().to_msg()) + + if feedback.event_type == InteractiveMarkerFeedback.MOUSE_UP: + self._save() + + transform_description.broadcaster.set_transform(msg) + + def _on_update_transform(self, transform_msg: TransformStamped) -> None: + """Publishes a transform and saves it to the file""" + + # Try to find a matching transform description to update + try: + transform_description = next( + t + for t in self.transforms + if t.model.parent == transform_msg.header.frame_id + and t.model.child == transform_msg.child_frame_id + ) + except StopIteration: + self.get_logger().error( + "The transform description does not exist for the parent frame " + f"'{transform_msg.header.frame_id}' and child frame " + f"'{transform_msg.child_frame_id}'. If you would like to create a new " + f"transform via API, publish it on the 'tf_static_create' topic." + ) + return + + # Edit parameters in place + rotation_np = numpify(transform_msg.transform.rotation) + translation_np = numpify(transform_msg.transform.translation) + transform_description.model.rotation = tuple(rotation_np.tolist()) + transform_description.model.translation = tuple(translation_np.tolist()) + + # Save changes to file + self._save() + + # Update the interactive marker so it reflects the new state. 
+ # This prevents a bug where after automatic calibration, when a user touches + # a transform on RVIZ it would 'reset' back to the position last known by the + # interaction server. + marker_name = transform_description.model.marker_name + assert self.interaction_server.setPose( + transform_description.model.marker_name, + pose=Pose( + orientation=msgify(Quaternion, rotation_np), + position=msgify(Point, translation_np / self.scale_factor), + ), + ) + self.interaction_server.applyChanges() + + # Send the transform + transform_description.broadcaster.set_transform(transform_msg) + self.get_logger().info(f"Successfully updated the transform '{marker_name}'") + + def _on_create_transform(self, transform_msg: TransformStamped) -> None: + """Create a transform if it didn't already exist, otherwise modify nothing.""" + try: + self._register_transform( + TransformModel( + created_via_api=True, + parent=transform_msg.header.frame_id, + child=transform_msg.child_frame_id, + rotation=tuple(numpify(transform_msg.transform.rotation).tolist()), + translation=tuple( + numpify(transform_msg.transform.translation).tolist() + ), + ) + ) + except DuplicateTransformError: + self.get_logger().info( + f"Not creating transform {transform_msg}, it already exists" + ) + else: + self._on_update_transform(transform_msg) + + def _save(self) -> None: + """Update the transforms file""" + self.get_logger().error(f"Saving transforms to file: {self.transforms_path}") + self.transforms_path.write_text( + TransformsFile( + transforms=[t.model for t in self.transforms] + ).model_dump_json() + ) + + +main = create_spin_function(InteractiveTransformPublisher, multi_threaded=True) diff --git a/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/client.py b/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/client.py new file mode 100644 index 0000000..355fe44 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/client.py @@ -0,0 +1,106 @@ +import logging + +import numpy as np +import numpy.typing as npt +from builtin_interfaces.msg import Time +from geometry_msgs.msg import Transform, TransformStamped +from rclpy.callback_groups import CallbackGroup, MutuallyExclusiveCallbackGroup +from std_msgs.msg import Header +from tf2_ros import Buffer + +from node_helpers.nodes import HelpfulNode +from node_helpers.qos import qos_profile_transient_reliable_keep_all +from node_helpers.ros2_numpy import msgify, numpify + + +class InteractiveTransformClient: + """A helper object for calling APIs on the InteractiveTransformPublisher""" + + def __init__( + self, + node: HelpfulNode, + namespace: str | None = None, + callback_group: CallbackGroup | None = None, + ): + callback_group = callback_group or MutuallyExclusiveCallbackGroup() + self.namespace = namespace or node.get_namespace().replace("/", "") + + self.update_transform = node.create_publisher( + TransformStamped, + f"/{self.namespace}/tf_static_updates", + qos_profile=qos_profile_transient_reliable_keep_all, + callback_group=callback_group, + ) + self.create_transform = node.create_publisher( + TransformStamped, + f"/{self.namespace}/tf_static_create", + qos_profile=qos_profile_transient_reliable_keep_all, + callback_group=callback_group, + ) + + def adjust_transform(self, tf_buffer: Buffer, adjustment: TransformStamped) -> None: + """Apply an adjustment to the transform. 
This is useful for calibration.""" + self.adjust_transform_np( + tf_buffer=tf_buffer, + parent=adjustment.header.frame_id, + child=adjustment.child_frame_id, + adjustment=numpify(adjustment.transform), + ) + + def adjust_transform_np( + self, + tf_buffer: Buffer, + parent: str, + child: str, + adjustment: npt.NDArray[np.float64], + relative_to: str | None = None, + ) -> None: + """Apply 4x4 adjustment to the transform. This is useful for calibration. + + :param tf_buffer: The buffer to look up the current transform + :param parent: The parent frame of the frame to move + :param child: The child frame to move + :param adjustment: The 4x4 adjustment matrix to apply + :param relative_to: The frame that 'adjustment' is in. By default, this is in + 'child' space. For example, if the adjustment is in 'frame_c' space, and is + == np.eye(4), then the child frame will be moved to the same position as + 'frame_c'. + """ + current_transform = numpify( + tf_buffer.lookup_transform( + source_frame=relative_to or child, target_frame=parent, time=Time() + ).transform + ) + adjusted_transform = adjustment @ current_transform + + logging.error( + f"Applying adjustment to {parent}->{child} of {adjustment.tolist()}" + ) + adjustment_msg = TransformStamped( + header=Header(frame_id=parent), + child_frame_id=child, + transform=msgify(Transform, adjusted_transform), + ) + self.update_transform.publish(adjustment_msg) + + def adjust_transform_along_axes( + self, + tf_buffer: Buffer, + parent: str, + child: str, + x: float = 0.0, + y: float = 0.0, + z: float = 0.0, + relative_to: str | None = None, + ) -> None: + adjustment = np.eye(4) + adjustment[0, 3] = x + adjustment[1, 3] = y + adjustment[2, 3] = z + return self.adjust_transform_np( + tf_buffer=tf_buffer, + parent=parent, + child=child, + adjustment=adjustment, + relative_to=relative_to, + ) diff --git a/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/exceptions.py b/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/exceptions.py new file mode 100644 index 0000000..ce8eee8 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/exceptions.py @@ -0,0 +1,6 @@ +class DuplicateTransformError(Exception): + """Raised when a transform is created that already exists""" + + +class MultipleParentsError(Exception): + """Raised when a child TF is configured to have multiple parents""" diff --git a/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/schemas.py b/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/schemas.py new file mode 100644 index 0000000..d8de43a --- /dev/null +++ b/pkgs/node_helpers/node_helpers/nodes/interactive_transform_publisher/schemas.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass + +import numpy as np +from builtin_interfaces.msg import Time +from geometry_msgs.msg import Quaternion, Transform, TransformStamped, Vector3 +from pydantic import BaseModel +from std_msgs.msg import Header + +from node_helpers.ros2_numpy import msgify +from node_helpers.tf import ConstantStaticTransformBroadcaster + + +class TransformModel(BaseModel): + """A serializable Transform""" + + parent: str + child: str + created_via_api: bool = False + """If True, this transform was created via the ROS API of this node. + If False, this transform originated from a parameters file. + + This information is used to automatically remove stale transforms from the file + if they are removed from configuration, but keep them if they were created via api. 
+ """ + + translation: tuple[float, float, float] = (0.0, 0.0, 0.0) + rotation: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0) + + def to_msg(self, stamp: Time) -> TransformStamped: + return TransformStamped( + header=Header(frame_id=self.parent, stamp=stamp), + child_frame_id=self.child, + transform=Transform( + translation=msgify(Vector3, np.array(self.translation)), + rotation=msgify(Quaternion, np.array(self.rotation)), + ), + ) + + @property + def marker_name(self) -> str: + return f"Transform {self.parent} -> {self.child}" + + +@dataclass +class TransformDescription: + """A state object to hold the Transform and its respective static broadcaster""" + + model: TransformModel + broadcaster: ConstantStaticTransformBroadcaster + + +class TransformsFile(BaseModel): + """The configuration file format""" + + transforms: list[TransformModel] diff --git a/pkgs/node_helpers/node_helpers/nodes/node_helpers_node.py b/pkgs/node_helpers/node_helpers/nodes/node_helpers_node.py deleted file mode 100644 index 9cddedf..0000000 --- a/pkgs/node_helpers/node_helpers/nodes/node_helpers_node.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Any - -from node_helpers.nodes import HelpfulNode -from pydantic import BaseModel -from std_msgs.msg import String -from node_helpers.spinning import create_spin_function -from rclpy.qos import qos_profile_services_default - -class ExampleNode(HelpfulNode): - - class Parameters(BaseModel): - # Define your ROS parameters here - publish_value: str - publish_hz: float - - def __init__(self, **kwargs: Any): - super().__init__("ExampleNode", **kwargs) - # Load parameters from the ROS parameter server - self.params = self.declare_from_pydantic_model(self.Parameters, "config") - - # Create a publisher - self.publisher = self.create_publisher( - String, "example_topic", qos_profile=qos_profile_services_default - ) - - # Create a timer - self.create_timer(1 / self.params.publish_hz, self.on_publish) - - - def on_publish(self) -> None: - msg = String() - msg.data = self.params.publish_value - self.publisher.publish(msg) - -main = create_spin_function(ExampleNode) \ No newline at end of file diff --git a/pkgs/node_helpers/node_helpers/nodes/placeholder.py b/pkgs/node_helpers/node_helpers/nodes/placeholder.py new file mode 100644 index 0000000..1ade6cd --- /dev/null +++ b/pkgs/node_helpers/node_helpers/nodes/placeholder.py @@ -0,0 +1,16 @@ +from typing import Any + +from node_helpers.nodes import HelpfulNode +from node_helpers.spinning import create_spin_function + + +class Placeholder(HelpfulNode): + """A placeholder node that can be used in launch files to keep them running even + if no other nodes are present + """ + + def __init__(self, **kwargs: Any): + super().__init__("placeholder", **kwargs) + + +main = create_spin_function(Placeholder) diff --git a/pkgs/node_helpers/node_helpers/nodes/sound_player/__init__.py b/pkgs/node_helpers/node_helpers/nodes/sound_player/__init__.py new file mode 100644 index 0000000..d85a5d0 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/nodes/sound_player/__init__.py @@ -0,0 +1,54 @@ +import logging +import subprocess +from pathlib import Path +from typing import Any + +from node_helpers_msgs.msg import PlaySound +from pydantic import BaseModel, DirectoryPath +from rclpy.qos import qos_profile_services_default + +from node_helpers.nodes.helpful_node import HelpfulNode +from node_helpers.spinning import create_spin_function + + +class SoundPlayer(HelpfulNode): + """ + A generic node for requesting sound effects to be played via 
topics. + """ + + class Parameters(BaseModel): + sound_effects_directory: DirectoryPath = Path("config/common/sound_effects/") + """A directory with *.ogg sound files""" + + def __init__(self, **kwargs: Any): + super().__init__("sound_player", **kwargs) + + self.params = self.declare_from_pydantic_model(self.Parameters, "player_config") + self.create_subscription( + PlaySound, + "play_sound", + self.on_play_sound, + qos_profile=qos_profile_services_default, + ) + + def on_play_sound(self, msg: PlaySound) -> None: + file_path = self.params.sound_effects_directory / msg.sound_filename + if not file_path.exists(): + raise FileNotFoundError( + f"The specified filename '{msg.sound_filename}' does" + " not exist in the configured sound directory: " + f"{self.params.sound_effects_directory}" + ) + + command = ["paplay", str(file_path)] + + logging.info(f"Playing sound using command: {command}") + try: + subprocess.run(command, check=True) + except subprocess.CalledProcessError: + logging.exception( + "Sound Player encountered error while attempting to play sound" + ) + + +main = create_spin_function(SoundPlayer) diff --git a/pkgs/node_helpers/node_helpers/nodes/sound_player/client.py b/pkgs/node_helpers/node_helpers/nodes/sound_player/client.py new file mode 100644 index 0000000..4d41113 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/nodes/sound_player/client.py @@ -0,0 +1,25 @@ +from node_helpers_msgs.msg import PlaySound +from rclpy.callback_groups import CallbackGroup, MutuallyExclusiveCallbackGroup +from rclpy.qos import qos_profile_services_default + +from node_helpers.nodes import HelpfulNode + + +class SoundPlayerClient: + """A helper object for calling APIs on the SoundPlayer""" + + def __init__( + self, + node: HelpfulNode, + namespace: str | None = None, + callback_group: CallbackGroup | None = None, + ): + callback_group = callback_group or MutuallyExclusiveCallbackGroup() + self.namespace = namespace or node.get_namespace().replace("/", "") + + self.play_sound = node.create_publisher( + PlaySound, + f"/{self.namespace}/play_sound", + qos_profile=qos_profile_services_default, + callback_group=callback_group, + ) diff --git a/pkgs/node_helpers/node_helpers/parameters/README.md b/pkgs/node_helpers/node_helpers/parameters/README.md new file mode 100644 index 0000000..b5797a2 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/parameters/README.md @@ -0,0 +1,7 @@ +# node_helpers.parameters + +This module contains a framework for managing and updating parameters in your ROS nodes. +By adding the `ParameterMixin` you can easily add parameters to your node and update them at runtime. +It also contains tools for loading, 'rendering', and passing parameters to nodes in launch files. + +The full documentation can be found under [docs/](../../../../docs/parameters.rst). 
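As a concrete illustration of the workflow this README describes, here is a minimal sketch of a node declaring typed parameters via `declare_from_pydantic_model`. It mirrors the `ExampleNode` pattern removed elsewhere in this diff; the node and parameter names are illustrative.

```python
# Minimal sketch: typed ROS parameters from a pydantic model. Names are illustrative.
from typing import Any

from pydantic import BaseModel

from node_helpers.nodes import HelpfulNode


class MotorNode(HelpfulNode):
    class Parameters(BaseModel):
        motor_name: str  # Required: must be provided by the launch/parameter file
        max_speed: float = 1.0  # Optional: falls back to this default

    def __init__(self, **kwargs: Any):
        super().__init__("motor_node", **kwargs)
        # Declares config.motor_name and config.max_speed, validates their types,
        # and subscribes the returned model to runtime parameter updates
        self.params = self.declare_from_pydantic_model(self.Parameters, "config")
        self.get_logger().info(f"Controlling motor: {self.params.motor_name}")
```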
\ No newline at end of file diff --git a/pkgs/node_helpers/node_helpers/parameters/__init__.py b/pkgs/node_helpers/node_helpers/parameters/__init__.py new file mode 100644 index 0000000..ff4880c --- /dev/null +++ b/pkgs/node_helpers/node_helpers/parameters/__init__.py @@ -0,0 +1,17 @@ +from .choosable_object import ( + Choosable, + DuplicateRegistrationError, + UnregisteredChoosableError, +) +from .loading import ( + FIELD_PLACEHOLDER, + Namespace, + ParameterLoader, + ParameterLoadingError, +) +from .parameter_mixin import ( + ParameterMixin, + RequiredParameterNotSetException, + UnfilledParametersFileError, +) +from .path import param_path diff --git a/pkgs/node_helpers/node_helpers/parameters/choosable_object.py b/pkgs/node_helpers/node_helpers/parameters/choosable_object.py new file mode 100644 index 0000000..18edef2 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/parameters/choosable_object.py @@ -0,0 +1,211 @@ +""" +In short, this module provides a seamless way to select Classes and Instances via +configuration, by having a global registration system and an easy way to register and +retrieve them. + +Take a look at the node_helpers/parameters documentation for full usage instructions. +""" + +from collections import defaultdict +from typing import Any, TypeVar, cast + +from pydantic import GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema + +SomeBaseClassType = TypeVar("SomeBaseClassType", bound="Choosable") + +# Define the global registry of uninstantiated classes +ClassRegistryType = defaultdict[type["Choosable"], dict[str, type["Choosable"]]] +_global_choosable_class_registry: ClassRegistryType = defaultdict(dict) +"""Global registry of all classes that can be loaded from configuration. + +This is a dictionary of dictionaries. The first key is the base class of the class +being loaded. The second key is the name of the class being loaded. The value is the +class itself. +""" + +# Define the global registry of instantiated instances +InstanceRegistryType = defaultdict[type["Choosable"], dict[str, "Choosable"]] +_global_choosable_instance_registry: InstanceRegistryType = defaultdict(dict) +"""Global registry of all instances that can be loaded from configuration.""" + + +class UnregisteredChoosableError(Exception): + """Used when a class or instance is not registered in the global registry""" + + +class DuplicateRegistrationError(Exception): + """Used when a class or instance is registered but the name has already been used""" + + +class Choosable: + """Base Class for any class that wants to be dynamically choosable by a user in + configuration.""" + + def __init_subclass__( + cls, registered_name: str | None = None, **kwargs: Any + ) -> None: + """ + :param registered_name: The name of the class that was registered. This is + the name that should be used in configuration. + By default this is the classes name, but subclasses can override + this. + :param kwargs: Other metaclass parameters, if any + :raises DuplicateRegistrationError: if the class was already registered + """ + super().__init_subclass__(**kwargs) + registered_name = registered_name or cls.__name__ + base_choosable_class = cls._find_base_choosable_class() + registry_for_class = _global_choosable_class_registry[base_choosable_class] + + if registered_name in registry_for_class: + raise DuplicateRegistrationError( + f"A class with the name {registered_name} has already been registered " + f"for {base_choosable_class.__name__}. 
Make sure you're not declaring " + f"multiple classes with the same name and same base class." + ) + registry_for_class[registered_name] = cls + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """This allows pydantic to accept instances of Choosable as + valid configuration values. Otherwise, you would have to specify the + 'allow_arbitrary_types' for any pydantic model that declares an instance of a + Choosable. + + Read more here: + https://docs.pydantic.dev/latest/concepts/types/#custom-types + # noqa: DAR101 + # noqa: DAR201 + """ + return core_schema.is_instance_schema(Choosable) + + @classmethod + def get_registered_child_class( + cls: type[SomeBaseClassType], name: str + ) -> "type[SomeBaseClassType]": + """Get a registered class by name that subclasses this class. + + Usage: + >>> class BaseChoosableClass(Choosable): + >>> pass + >>> + >>> class MyChoosableClass(BaseChoosableClass): + >>> pass + >>> + >>> cls_ = BaseChoosableClass.get_registered_child_class("MyChoosableClass") + >>> assert cls_ == MyChoosableClass + + :param name: The name of the class to retrieve + :return: The class that was registered with the given name + :raises UnregisteredChoosableError: If the class was not registered + """ + base_choosable_class = cls._find_base_choosable_class() + registry_for_class = _global_choosable_class_registry[base_choosable_class] + + try: + return cast(type[SomeBaseClassType], registry_for_class[name]) + except KeyError as e: + raise UnregisteredChoosableError( + f"{cls.__name__} with name '{name}' was not registered. Make sure " + f"{cls.__name__} has a parent that inherits from " + f"Choosable so that it's added to the global registry. " + f"Registered classes for '{base_choosable_class.__name__}' are: " + f"{list(registry_for_class.keys())}" + ) from e + + @classmethod + def _find_base_choosable_class( + cls: type[SomeBaseClassType], + ) -> type[SomeBaseClassType]: + """Find the base class that subclasses Choosable. This will be + used as a key for the global registry. + + :return: The class that subclasses Choosable + :raises RuntimeError: If the class doesn't subclass Choosable + """ + # Exit if this cls is already the base inheritor + is_root_class = Choosable in cls.__bases__ + if is_root_class: + return cls + + for base in cls.__bases__: + # Recursively check the base classes of the current base class + if issubclass(base, Choosable): + recursive_base = base._find_base_choosable_class() # noqa: SLF001 + if recursive_base: + return cast(type[SomeBaseClassType], recursive_base) + + raise RuntimeError( + "This shouldn't happen. There should always exist a base class subclassing " + "Choosable. Are you calling this method directly on " + "Choosable? If so, you should inherit it and then call it on" + " a subclass." + ) + + def register_instance(self, name: str) -> None: + """Register this instance with the global registry, with the given name + :param name: The name to register the instance under + :raises DuplicateRegistrationError: if the instance was already registered + """ + base_choosable_class = self._find_base_choosable_class() + registry_for_class = _global_choosable_instance_registry.setdefault( + base_choosable_class, {} + ) + + if name in registry_for_class: + raise DuplicateRegistrationError( + f"An instance with the name {name} has already been registered for " + f"{base_choosable_class.__name__}. 
Make sure you're not declaring " + f"multiple instances with the same name and same base class." + ) + registry_for_class[name] = self + + @classmethod + def get_registered_instance( + cls: type[SomeBaseClassType], name: str + ) -> SomeBaseClassType: + """Get a registered instance of this class""" + base_choosable_class = cls._find_base_choosable_class() + registry_for_class = _global_choosable_instance_registry[base_choosable_class] + + try: + return cast(SomeBaseClassType, registry_for_class[name]) + except KeyError as ex: + raise UnregisteredChoosableError( + f"{cls.__name__} with name '{name}' was not registered. Make sure " + f"{cls.__name__} is instantiated with the register_instance method " + f"so that it's added to the global registry. " + f"Registered instances for '{base_choosable_class.__name__}' are: " + f"{list(registry_for_class.keys())}" + ) from ex + + @classmethod + def get_registered_class_name(cls: type[SomeBaseClassType]) -> str: + """Get the name of the registered class""" + base_choosable_class = cls._find_base_choosable_class() + registry_for_class = _global_choosable_class_registry[base_choosable_class] + return next(name for name, type_ in registry_for_class.items() if type_ is cls) + + def get_registered_instance_name(self) -> str: + """Get the name of the registered instance""" + base_choosable_class = self._find_base_choosable_class() + registry_for_instance = _global_choosable_instance_registry[ + base_choosable_class + ] + try: + return next( + name + for name, instance in registry_for_instance.items() + if instance is self + ) + except StopIteration as ex: + raise UnregisteredChoosableError( + f"The instance of {self.__class__.__name__} was not registered. Make " + f"sure {self.__class__.__name__} is registered by calling the " + f"register_instance method so that it's added to the global registry. " + f"Registered instances for '{base_choosable_class.__name__}' are: " + f"{list(registry_for_instance.keys())}" + ) from ex diff --git a/pkgs/node_helpers/node_helpers/parameters/loading.py b/pkgs/node_helpers/node_helpers/parameters/loading.py new file mode 100644 index 0000000..dbd67e8 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/parameters/loading.py @@ -0,0 +1,217 @@ +import tempfile +from copy import deepcopy +from pathlib import Path +from typing import Any, Generic, TypeVar + +import yaml +from pydantic import BaseModel +from rclpy.parameter import Parameter + +MetaParameters = TypeVar("MetaParameters", bound=BaseModel) + + +class Namespace: + """A representation of ROS's period-separated namespacing system. This can be used + to represent a fully qualified node name, like 'manipulator.mock_manipulator', or a + nested parameter name, like 'steppers.x.max_pos'. 
+ """ + + def __init__(self, value: list[str]) -> None: + self.items = value + + @staticmethod + def from_string(value: str) -> "Namespace": + return Namespace(value.split(".")) + + def __str__(self) -> str: + return ".".join(self.items) + + def __repr__(self) -> str: + return f"Namespace({self!s})" + + def __add__(self, other: "Namespace") -> "Namespace": + return Namespace(self.items + other.items) + + def with_item(self, item: str) -> "Namespace": + return Namespace(self.items + [item]) + + def __hash__(self) -> int: + return hash(tuple(self.items)) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Namespace): + return self.items == other.items + raise NotImplementedError + + +class ParameterLoader(Generic[MetaParameters]): + """This class facilitates a workflow for ROS development that expects there to be + two layers of arguments/parameters/configuration that is typically needed for a + project: + + 1) The meta parameters. These are all parameters needed for the launch file + itself, and are not used to feed arguments into nodes themselves. + 2) The node parameters. These are actual ROS parameters that can be fed into + individual nodes. + + The workflow includes loading one or more yaml files and combining them into a + single file, allowing for a layered configuration approach in which common + configuration is placed in a base parameters file, and 'override' files (which are + not typically committed via git) can be created for different hardware environments + or use cases. + """ + + def __init__( + self, + *, + parameters_directory: Path, + override_file: Path | None = None, + meta_parameters_schema: type[MetaParameters] | None = None, + ) -> None: + """ + :param parameters_directory: A directory with one or more yaml files to try + loading. All yaml files will be loaded in alphanumeric order by file name, + and any colliding keys will override the previous entry. + :param override_file: A file to load last, which will override any previous + configuration. + :param meta_parameters_schema: The pydantic schema with which to load the + "meta_parameters" section of the configuration ros_parameters_file. + :raises FileNotFoundError: If no yaml files are found in the provided directory + """ + + files_in_order = [*sorted(parameters_directory.glob("*.yaml"))] + if len(files_in_order) == 0: + raise FileNotFoundError(f"No yaml files found in '{parameters_directory}'!") + + if override_file: + files_in_order.append(override_file) + + yaml_dict: dict[str, Any] = {} + for file_ in filter(Path.is_file, files_in_order): + yaml_dict_to_merge = yaml.full_load(file_.read_text()) + if yaml_dict_to_merge is not None: + yaml_dict = self._merge_dictionaries(yaml_dict, yaml_dict_to_merge) + + # Load the meta parameters, if a schema was specified + self._meta_parameters = None + if meta_parameters_schema: + self._meta_parameters = self._load_meta_params( + yaml_dict, meta_parameters_schema + ) + + # Load the ros node parameters + self.parameters: dict[Namespace, dict[Namespace, Any]] = {} + self._load_file_params(yaml_dict) + + # Save the ros node parameters to a file, for easy parameter injection + yaml_dict.pop(_META_PARAMETERS_KEY, None) + self.ros_parameters_file = self._save_yaml(yaml_dict) + """A path a yaml ros_parameters_file with the attributes combined, containing + the result of merging all provided parameter files together, without the + meta parameters. 
+ """ + + @property + def meta_parameters(self) -> MetaParameters: + if self._meta_parameters is None: + raise RuntimeError( + "You cannot use meta_parameters if a schema was not passed in to the " + "ParameterLoader!" + ) + return self._meta_parameters + + def parameters_for_node(self, namespace: Namespace) -> list[Parameter]: + """ + :param namespace: The namespace and name of the node, like + "manipulator.manipulator_capcom" + :return: The node's parameters, in a format that can be provided to a node's + parameter_overrides argument + """ + node_params = self.parameters.get(namespace, {}) + return [Parameter(name=str(k), value=v) for k, v in node_params.items()] + + @staticmethod + def _merge_dictionaries( + a: dict[str, Any], b: dict[str, Any], path: list[str] | None = None + ) -> dict[str, Any]: + """Merges dictionary b into a and returns the resulting dictionary.""" + a, b = deepcopy(a), deepcopy(b) + + if path is None: + path = [] + for key in b: + if key in a: + if isinstance(a[key], dict) and isinstance(b[key], dict): + a[key] = ParameterLoader._merge_dictionaries( + a[key], b[key], path + [str(key)] + ) + elif a[key] != b[key]: + a[key] = b[key] + else: + a[key] = b[key] + return a + + def _load_meta_params( + self, data_yaml: dict[str, Any], schema: type[MetaParameters] + ) -> MetaParameters: + if not data_yaml.get(_META_PARAMETERS_KEY, False): + raise ParameterLoadingError( + "A schema was specified for meta parameters, but the key " + f"'{_META_PARAMETERS_KEY}' was not found!" + ) + meta_parameters_dict = data_yaml[_META_PARAMETERS_KEY] + return schema.model_validate(meta_parameters_dict) + + def _load_file_params(self, data_yaml: dict[str, Any]) -> None: + for key, fields in data_yaml.items(): + try: + self._load_namespace_params(fields, Namespace([key])) + except ParameterLoadingError as ex: + raise ParameterLoadingError(f"In field {key}: {ex}") from ex + + def _load_namespace_params( + self, data_yaml: dict[str, Any], node_name: Namespace + ) -> None: + for key, value in data_yaml.items(): + if isinstance(value, dict): + if key == _PARAMETERS_KEY: + self._load_node_params(value, node_name, Namespace([])) + else: + self._load_namespace_params(value, node_name.with_item(key)) + + def _load_node_params( + self, input_: dict[str, Any], node_name: Namespace, param_name: Namespace + ) -> None: + for key, value in input_.items(): + if isinstance(value, dict): + self._load_node_params(value, node_name, param_name.with_item(key)) + else: + if node_name not in self.parameters: + self.parameters[node_name] = {} + self.parameters[node_name][param_name.with_item(key)] = value + + def _save_yaml(self, yaml_dict: dict[str, Any]) -> Path: + # Unfortunately this is the only way to make the yaml writer not create weird + # files full of yaml aliases that ROS can't parse. + yaml.Dumper.ignore_aliases = lambda *args: True # type: ignore + _YAML_FILE.write_text(yaml.dump(yaml_dict)) + return _YAML_FILE + + +class ParameterLoadingError(Exception): + """The parameters file has a formatting issue""" + + +FIELD_PLACEHOLDER = "" +"""The placeholder text put on parameters that need to be filled in manually""" + +_PARAMETERS_KEY = "ros__parameters" +"""ROS's magic key that signals when node namespacing ends and parameters begin""" + +_META_PARAMETERS_KEY = "meta_parameters" +"""The name of the node at the root of the ros_parameters_file that holds the meta +parameters for the launch file. 
""" + +_YAML_FILE = Path(tempfile.gettempdir()) / "parameters.yaml" +"""The temporary ros_parameters_file location to write the final merged + ros_parameters_file to""" diff --git a/pkgs/node_helpers/node_helpers/parameters/parameter_mixin.py b/pkgs/node_helpers/node_helpers/parameters/parameter_mixin.py new file mode 100644 index 0000000..ffa9a84 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/parameters/parameter_mixin.py @@ -0,0 +1,307 @@ +import logging +import typing +from types import GenericAlias, NoneType, UnionType +from typing import Any, TypeVar, cast + +from pydantic import BaseModel +from pydantic_core import PydanticUndefined +from rcl_interfaces.msg import ParameterDescriptor, SetParametersResult, ParameterValue +from rclpy.exceptions import ParameterAlreadyDeclaredException, ParameterException +from rclpy.node import Node +from rclpy.parameter import Parameter + +from .parsing import ParsableType, get_parser + +PydanticType = TypeVar("PydanticType", bound=BaseModel) + + +class RequiredParameterNotSetException(Exception): + """Raised when a parmeter was marked as required, but was not set. This can occur + if the user hasn't set the parameter via the launch file or command line.""" + + +class ParameterMixin: + """This mixin enables declaring parameters within a node in a more sane manner + than what rclpy allows by default. + + The benefits of using this mixin is: + 1) Less boilerplate + 2) Support for parsing parameters into Pydantic models! + 3) Type checking! + 4) Validate that the parameter is _actually_ set from outside source, when + `required=True` is set. + + """ + + def declare_and_get_parameter( + self: Node, + name: str, + type_: type[ParsableType], + description: str = "", + default_value: ParsableType | None = None, + required: bool = False, + ) -> ParsableType: + """ + Declare a parameter and immediately receive the (current) value of the param, + where the main use case is to declare a parameter that will be filled in the + launch file. + + :param name: The name of the parameter. + :param type_: What type should be returned by this function. Must have a + supported parser! + :param description: The ROS parameter description. + :param default_value: The default to set this parameter to. Must be None if + required=True. + :param required: If True, then this function will raise an error if the value + was not assigned from an outside source. Use this for values that don't make + sense to have defaults for, but that are also necessary for normal usage. + + :raises RequiredParameterNotSetException: If the parameter is marked as + `required` but is not set by the launch file, this will be raised. + :raises TypeError: If a given value doesn't match the specified type + :raises ValueError: If the parameters to the function were invalid + :raises UnfilledParametersFileError: If an unfilled sentinel value is present + in one of the parameters + + :return: The current value of the parameter + """ + + parser = get_parser(type_) + + if default_value is not None: + # Convert default_value to the 'config' type + default_value_config = parser.as_config_type(default_value) + + if required: + raise ValueError( + "If required=True, then default_value must be None. 
Got " + f"{default_value=}" + ) + + # Validate the default_value matches the specified type + if not required and not parser.ros_type.check(default_value_config): + raise TypeError( + f"The default value {default_value} for parameter '{name}' does not" + f" match the specified type {parser.ros_type}" + ) + else: + default_value_config = None + + # Declare the parameter and retrieve the current value + descriptor = ParameterDescriptor( + description=description, type=parser.ros_type.value, dynamic_typing=True + ) + + retrieved_value = self.declare_parameter( + name=name, + value=default_value_config, + descriptor=descriptor, + ignore_override=False, + ).value + + if required and retrieved_value is None: + raise RequiredParameterNotSetException( + f"The parameter '{name}' is required to be set externally. Please " + f"specify this parameter in the launch or configuration file. " + f"Namespace: {self.get_namespace()}" + ) + + if ( + isinstance(retrieved_value, str) + and FIELD_PLACEHOLDER in retrieved_value.lower() + ): + raise UnfilledParametersFileError( + f"Field '{name!s}' has not been filled in! Please fill in any fields" + f" with the value '{FIELD_PLACEHOLDER}' with an actual value." + ) + + if not parser.ros_type.check(retrieved_value): + raise TypeError( + f"The parameter '{name}' was set with an incorrect type. The specified " + f"type was {parser.ros_type}, but the final value had a type of " + f"{type(retrieved_value)} and value of '{retrieved_value!s}'" + ) + + return parser.from_config_type(retrieved_value) + + def subscribe_attribute_to_updates( + self: Node, + attr: str, + parameter_name: str, + type_: type[ParsableType], + object_with_attr: object | None = None, + ) -> None: + """Subscribe an attribute on an object to updates for when a parameter gets + changed. + + :param attr: The attribute to update on the object + :param parameter_name: The parameter name to use for updates + :param type_: What type will the updated value be parsed as + :param object_with_attr: The object that has the attribute. If none is set, + it is assumed to be `self`. + :raises AttributeError: If the `object_with_attr` does not in fact have the + specified `attr`. + """ + + object_with_attr = object_with_attr or self + pretty_attribute_name = f"{object_with_attr.__class__.__name__}.{attr}" + + if not hasattr(object_with_attr, attr): + raise AttributeError(f"The attribute {pretty_attribute_name} was not found") + + def on_set_parameter(parameters: list[Parameter]) -> SetParametersResult: + parser = get_parser(type_) + for parameter in parameters: + if parameter.name != parameter_name: + continue + + logging.info( + f"Setting {pretty_attribute_name}={parameter.value}. Relevant " + f"parameter name='{parameter_name}'" + ) + + parsed_value = parser.from_config_type(parameter.value) + setattr(object_with_attr, attr, parsed_value) + break + + return SetParametersResult(successful=True) + + self.add_on_set_parameters_callback(callback=on_set_parameter) + + def declare_from_pydantic_model( + self, model: type[PydanticType], prefix: str, subscribe_to_updates: bool = True + ) -> PydanticType: + """ + :param model: The pydantic model to instantiate from + :param prefix: The parameter name prefix. 
For example, if model.cool_attribute
+            exists and the prefix is "stepper_1", then in configuration cool_attribute
+            would be set as
+
+            stepper_1:
+              cool_attribute: value
+        :param subscribe_to_updates: When True, all attributes of the pydantic model
+            will be capable of being dynamically updated using Node.set_parameters()
+        :raises ValueError: If the prefix is an empty string
+        :return: An instantiated model
+        """
+        if prefix == "":
+            raise ValueError("Empty prefixes are not allowed for pydantic parameters!")
+
+        model_kwargs = {}
+        subscribable_attributes: list[tuple[str, type, str]] = []
+
+        for attr_name, field in model.model_fields.items():
+            type_ = cast(type[Any], field.annotation)
+            parameter_name = f"{prefix}.{attr_name}"
+
+            # Generic aliases like list[int] don't work with issubclass, so we need
+            # this special case
+            is_union = isinstance(type_, UnionType)
+            try_types = [type_] if not is_union else typing.get_args(type_)
+            is_pydantic_model, parsed_value, parsed_type = self._try_to_parse_types(
+                parameter_name=parameter_name,
+                possible_types=try_types,
+                default_value=field.get_default(),
+                required=field.is_required(),
+                description=field.description or "",
+                subscribe_to_updates=subscribe_to_updates,
+            )
+            model_kwargs[attr_name] = parsed_value
+
+            if not is_pydantic_model:
+                subscribable_attributes.append((attr_name, parsed_type, parameter_name))
+
+        instantiated: PydanticType = model(**model_kwargs)
+
+        # Subscribe the instantiated object to parameter updates
+        if subscribe_to_updates:
+            for attr_name, parseable_type, parameter_name in subscribable_attributes:
+                self.subscribe_attribute_to_updates(
+                    attr=attr_name,
+                    parameter_name=parameter_name,
+                    type_=parseable_type,
+                    object_with_attr=instantiated,
+                )
+        return instantiated
+
+    def _try_to_parse_types(
+        self,
+        parameter_name: str,
+        possible_types: typing.Sequence[type],
+        default_value: Any,
+        required: bool,
+        subscribe_to_updates: bool,
+        description: str = "",
+    ) -> tuple[bool, Any, type]:
+        """Parse the type and return (is_pydantic_model, parsed_value, parsed_type)"""
+
+        last_exception: BaseException | None = None
+
+        # Union args represent 'None' as NoneType, so check for NoneType here
+        if NoneType in possible_types and possible_types[-1] is not NoneType:
+            raise ValueError(
+                "None must be the last element in any pydantic model with Union types"
+            )
+
+        # 'declare_and_get_parameter' cannot use PydanticUndefined as a default value,
+        # so we use None and check at the end if the default value was 'None'
+        default = None if default_value is PydanticUndefined else default_value
+        for try_type in possible_types:
+            is_list = isinstance(try_type, GenericAlias)
+            is_basemodel = False if is_list else issubclass(try_type, BaseModel)
+
+            try:
+                if is_basemodel:
+                    parsed_value: Any = self.declare_from_pydantic_model(
+                        model=try_type,
+                        prefix=parameter_name,
+                        subscribe_to_updates=subscribe_to_updates,
+                    )
+                    return True, parsed_value, try_type
+                else:
+                    parsed_value = self.declare_and_get_parameter(
+                        name=parameter_name,
+                        type_=try_type,
+                        description=description,
+                        default_value=default,
+                        required=required,
+                    )
+                    return False, parsed_value, try_type
+
+            except ParameterAlreadyDeclaredException:
+                # This is a special case where a pydantic model has None set as the
+                # default value. In this case, we make sure that None was 'try_type',
+                # and just return None as the parsed value
+                if default_value is None and try_type is NoneType:
+                    logging.debug(
+                        f"Parameter '{parameter_name}' was already declared, but the "
+                        f"default value was None. Using None as the parsed value."
+                    )
+                    return False, None, try_type
+                raise
+            except (
+                ParameterException,
+                TypeError,
+                RequiredParameterNotSetException,
+            ) as e:
+                # This exception should never happen for non-pydantic models
+                if not is_basemodel and isinstance(e, RequiredParameterNotSetException):
+                    raise
+
+                # This is a common case where the parameter didn't match the type.
+                # We will continue trying to parse the parameter with the next type
+                logging.debug(
+                    f"Couldn't parse {parameter_name} using type {try_type}, out of "
+                    f"possible types {possible_types}. Reason {type(e)}('{e}')"
+                )
+                last_exception = e
+
+        assert last_exception is not None
+        raise last_exception
+
+
+FIELD_PLACEHOLDER = ""
+"""The placeholder text put on parameters that need to be filled in manually"""
+
+
+class UnfilledParametersFileError(Exception):
+    """One or more fields were left unfilled in the resulting parameters file"""
diff --git a/pkgs/node_helpers/node_helpers/parameters/parsing/__init__.py b/pkgs/node_helpers/node_helpers/parameters/parsing/__init__.py
new file mode 100644
index 0000000..dcf19fd
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/parameters/parsing/__init__.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from .base_parser import BaseParser
+from .choosable_class import ChoosableClassParser
+from .choosable_instance import ChoosableInstanceParser
+from .enum import EnumParser
+from .none import NoneParser
+from .passthrough import PassthroughParser
+from .path import PathParser
+from .utils import CONFIG_TO_ROS_MAPPING, ConfigurationType, ParsableType
+
+
+def get_parser(for_parseable_type: type[ParsableType]) -> BaseParser[Any, ParsableType]:
+    """Map all the supported declarable types to their respective parsers
+
+    :param for_parseable_type: Either a ROS Parameter type (e.g. int, float, str) or a
+        custom type with a custom parser (e.g. a custom Enum, a Path, etc).
+    :return: The parser that can convert between the Python type ROS uses and the
+        custom type
+    :raises RuntimeError: If the requested type isn't yet supported.
+    """
+
+    parsers: list[BaseParser[Any, Any]] = [
+        # These are custom types, where the final type is not equivalent to the
+        # type written in yaml / returned by ROS.
+        NoneParser(),
+        PathParser(),
+        EnumParser(for_parseable_type),
+        ChoosableClassParser(for_parseable_type),  # type: ignore
+        ChoosableInstanceParser(for_parseable_type),
+        # These are passthrough types, where the ROS configuration type is equal to the
+        # final "parsed" type
+        PassthroughParser(bool, bool),
+        PassthroughParser(int, int),
+        PassthroughParser(float, float),
+        PassthroughParser(str, str),
+        PassthroughParser(list[bytes], list[bytes]),
+        PassthroughParser(list[bool], list[bool]),
+        PassthroughParser(list[int], list[int]),
+        PassthroughParser(list[float], list[float]),
+        PassthroughParser(list[str], list[str]),
+    ]
+
+    parser = next((p for p in parsers if p.can_parse(for_parseable_type)), None)
+    if parser is None:
+        raise RuntimeError(
+            f"The type '{for_parseable_type}' is not a supported parameter "
+            f"type! If you'd like to use it, add a custom Parser for it and register it"
+            " here in this function."
+        )
+    return parser
diff --git a/pkgs/node_helpers/node_helpers/parameters/parsing/base_parser.py b/pkgs/node_helpers/node_helpers/parameters/parsing/base_parser.py
new file mode 100644
index 0000000..ce97847
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/parameters/parsing/base_parser.py
@@ -0,0 +1,38 @@
+from abc import ABC, abstractmethod
+from typing import Generic, TypeVar
+
+from rclpy.parameter import Parameter
+
+from .utils import CONFIG_TO_ROS_MAPPING, ConfigurationType
+
+ParsedType = TypeVar("ParsedType")
+"""This represents the type produced after the parser converts the raw value received
+from ROS"""
+
+
+class BaseParser(ABC, Generic[ConfigurationType, ParsedType]):
+    def __init__(
+        self, config_type: type[ConfigurationType], parsed_type: type[ParsedType]
+    ):
+        """
+        :param config_type: The type as seen in a yaml file.
+        :param parsed_type: The type after parsing. For example, an enum, a Path, etc.
+        """
+        self.config_type: type[ConfigurationType] = config_type
+        self.parsed_type: type[ParsedType] = parsed_type
+
+    @property
+    def ros_type(self) -> Parameter.Type:
+        """Return the ROS enum of the configuration type for this parser"""
+        return CONFIG_TO_ROS_MAPPING[self.config_type]
+
+    def can_parse(self, type_: type) -> bool:
+        """This method can be overridden to allow a parser to take on a broader scope"""
+        return type_ == self.parsed_type
+
+    @abstractmethod
+    def from_config_type(self, config_value: ConfigurationType) -> ParsedType:
+        pass
+
+    @abstractmethod
+    def as_config_type(self, parsed_value: ParsedType) -> ConfigurationType:
+        pass
diff --git a/pkgs/node_helpers/node_helpers/parameters/parsing/choosable_class.py b/pkgs/node_helpers/node_helpers/parameters/parsing/choosable_class.py
new file mode 100644
index 0000000..3c4d3be
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/parameters/parsing/choosable_class.py
@@ -0,0 +1,63 @@
+"""This module needs a type hinting expert to help remove the type: ignore comments."""
+
+import typing
+from typing import Any, TypeVar
+
+from node_helpers.parameters.choosable_object import Choosable
+
+from .base_parser import BaseParser
+
+ChoosableType = TypeVar("ChoosableType", bound=Choosable)
+
+
+class ChoosableClassParser(BaseParser[str, type[Choosable]]):
+    def __init__(self, for_chooseable_type: type[ChoosableType]):
+        """A parser for a class that subclasses Choosable.
+        :param for_chooseable_type: The type of the class that this parser will parse.
+
+        For example,
+        >>> ChoosableClassParser(for_chooseable_type=type[SomeBaseChoosableClass])
+        """
+        super().__init__(str, for_chooseable_type)  # type: ignore
+
+    def can_parse(self, type_: type[Any]) -> bool:
+        """ChoosableTypes are intended to be specified as 'type[ChoosableClass]' in
+        configuration. For that reason, type_ is always a type wrapper, and in order to
+        figure out if the inside of type[THING] is a Choosable type, we need to check
+        whether the single argument of the wrapper (typing.get_args(type_)[0]) is a
+        subclass of Choosable.
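+
+        For example, given parser = ChoosableClassParser(type[CoolChoosableClass]),
+        where CoolChoosableClass is a hypothetical Choosable subclass:
+        >>> parser.can_parse(type[CoolChoosableClass])  # True
+        >>> parser.can_parse(CoolChoosableClass)  # False, not wrapped in type[...]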
+
+        :param type_: The type to validate as a ChoosableType
+        :return: True if type_ is a Choosable, False otherwise
+        """
+        try:
+            unpacked_type: Any = self._unpack_type_from_type_wrapper(type_)
+        except ValueError:
+            return False
+
+        return issubclass(unpacked_type, Choosable)
+
+    def from_config_type(self, config_value: str) -> type[ChoosableType]:
+        unpacked = self._unpack_type_from_type_wrapper(self.parsed_type)  # type: ignore
+        return unpacked.get_registered_child_class(config_value)  # type: ignore
+
+    def as_config_type(self, parsed_value: type[ChoosableType]) -> str:
+        return parsed_value.get_registered_class_name()
+
+    def _unpack_type_from_type_wrapper(
+        self, type_: type[ChoosableType]
+    ) -> type[ChoosableType]:
+        """Unpack the choosable type from a type wrapper.
+        For example, if type_ is type[CoolChoosableClass], return CoolChoosableClass.
+
+        :param type_: The type to unpack
+        :return: The unpacked type
+        :raises ValueError: If the type was not a type wrapper holding 1 type
+        """
+        if typing.get_origin(type_) is not type:
+            raise ValueError
+
+        type_args = typing.get_args(type_)
+        if len(type_args) != 1:
+            raise ValueError
+
+        return type_args[0]  # type: ignore
diff --git a/pkgs/node_helpers/node_helpers/parameters/parsing/choosable_instance.py b/pkgs/node_helpers/node_helpers/parameters/parsing/choosable_instance.py
new file mode 100644
index 0000000..9770f77
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/parameters/parsing/choosable_instance.py
@@ -0,0 +1,29 @@
+from typing import Any, TypeVar
+
+from node_helpers.parameters.choosable_object import Choosable
+
+from .base_parser import BaseParser
+
+ChoosableInstance = TypeVar("ChoosableInstance", bound=Choosable)
+
+
+class ChoosableInstanceParser(BaseParser[str, ChoosableInstance]):
+    def __init__(self, for_chooseable_instance: type[ChoosableInstance]):
+        """A parser for retrieving an instance of a class that subclasses Choosable
+
+        For example,
+        >>> ChoosableInstanceParser(for_chooseable_instance=SomeBaseChoosableClass)
+
+        :param for_chooseable_instance: The type of the class that this parser will
+            parse. It is _not_ wrapped in a type[], unlike the ChoosableClassParser.
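+
+        As a sketch, a yaml value of "some_registered_name" would then be resolved
+        roughly as:
+        >>> SomeBaseChoosableClass.get_registered_instance("some_registered_name")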
+        """
+        super().__init__(str, for_chooseable_instance)
+
+    def can_parse(self, type_: type[Any]) -> bool:
+        # Guard with isinstance: issubclass raises TypeError for generic aliases
+        # like list[int] or type[X]
+        return isinstance(type_, type) and issubclass(type_, Choosable)
+
+    def from_config_type(self, config_value: str) -> ChoosableInstance:
+        return self.parsed_type.get_registered_instance(config_value)
+
+    def as_config_type(self, parsed_value: ChoosableInstance) -> str:
+        return parsed_value.get_registered_instance_name()
diff --git a/pkgs/node_helpers/node_helpers/parameters/parsing/enum.py b/pkgs/node_helpers/node_helpers/parameters/parsing/enum.py
new file mode 100644
index 0000000..5c98ab5
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/parameters/parsing/enum.py
@@ -0,0 +1,20 @@
+from enum import Enum
+from typing import Any, TypeVar
+
+from .base_parser import BaseParser
+
+EnumType = TypeVar("EnumType", bound=Enum)
+
+
+class EnumParser(BaseParser[str, EnumType]):
+    def __init__(self, enum_type: type[EnumType]):
+        super().__init__(str, enum_type)
+
+    def can_parse(self, type_: Any) -> bool:
+        # Guard with isinstance: issubclass raises TypeError for generic aliases
+        return isinstance(type_, type) and issubclass(type_, Enum)
+
+    def from_config_type(self, config_value: str) -> EnumType:
+        return self.parsed_type(config_value)
+
+    def as_config_type(self, parsed_value: EnumType) -> str:
+        return self.config_type(parsed_value.value)
diff --git a/pkgs/node_helpers/node_helpers/parameters/parsing/none.py b/pkgs/node_helpers/node_helpers/parameters/parsing/none.py
new file mode 100644
index 0000000..d1933d3
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/parameters/parsing/none.py
@@ -0,0 +1,17 @@
+from types import NoneType
+
+from .base_parser import BaseParser
+
+
+class NoneParser(BaseParser[None, None]):  # type: ignore
+    def __init__(self) -> None:
+        super().__init__(NoneType, type[None])  # type: ignore
+
+    def can_parse(self, type_: type) -> bool:
+        return type_ is NoneType
+
+    def from_config_type(self, config_value: None) -> None:
+        return None
+
+    def as_config_type(self, parsed_value: None) -> None:
+        return None
diff --git a/pkgs/node_helpers/node_helpers/parameters/parsing/passthrough.py b/pkgs/node_helpers/node_helpers/parameters/parsing/passthrough.py
new file mode 100644
index 0000000..29a373d
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/parameters/parsing/passthrough.py
@@ -0,0 +1,14 @@
+from typing import Generic
+
+from .base_parser import BaseParser
+from .utils import ConfigurationType
+
+
+class PassthroughParser(
+    BaseParser[ConfigurationType, ConfigurationType], Generic[ConfigurationType]
+):
+    def from_config_type(self, config_value: ConfigurationType) -> ConfigurationType:
+        return config_value
+
+    def as_config_type(self, parsed_value: ConfigurationType) -> ConfigurationType:
+        return parsed_value
diff --git a/pkgs/node_helpers/node_helpers/parameters/parsing/path.py b/pkgs/node_helpers/node_helpers/parameters/parsing/path.py
new file mode 100644
index 0000000..dde0cc6
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/parameters/parsing/path.py
@@ -0,0 +1,17 @@
+from pathlib import Path
+
+from .base_parser import BaseParser
+
+
+class PathParser(BaseParser[str, Path]):
+    def __init__(self) -> None:
+        super().__init__(str, Path)
+
+    def can_parse(self, type_: type) -> bool:
+        # Guard with isinstance: issubclass raises TypeError for generic aliases
+        return isinstance(type_, type) and issubclass(type_, Path)
+
+    def from_config_type(self, config_value: str) -> Path:
+        return self.parsed_type(config_value)
+
+    def as_config_type(self, parsed_value: Path) -> str:
+        return self.config_type(parsed_value)
diff --git a/pkgs/node_helpers/node_helpers/parameters/parsing/utils.py b/pkgs/node_helpers/node_helpers/parameters/parsing/utils.py
new file mode 100644
index 0000000..716a1c6
--- /dev/null +++ b/pkgs/node_helpers/node_helpers/parameters/parsing/utils.py @@ -0,0 +1,36 @@ +from types import NoneType +from typing import TypeVar + +from rclpy.parameter import Parameter + +CONFIG_TO_ROS_MAPPING = { + NoneType: Parameter.Type.NOT_SET, + bool: Parameter.Type.BOOL, + int: Parameter.Type.INTEGER, + float: Parameter.Type.DOUBLE, + str: Parameter.Type.STRING, + list[bytes]: Parameter.Type.BYTE_ARRAY, + list[bool]: Parameter.Type.BOOL_ARRAY, + list[int]: Parameter.Type.INTEGER_ARRAY, + list[float]: Parameter.Type.DOUBLE_ARRAY, + list[str]: Parameter.Type.STRING_ARRAY, +} +"""This is a nice wrapper to allow python types to map to the ROS parameter API types""" + +ConfigurationType = TypeVar( + "ConfigurationType", + type[None], + bool, + int, + float, + str, + bytes, + list[bool], + list[bytes], + list[int], + list[float], + list[str], +) +"""All supported ROS parameter types""" + +ParsableType = TypeVar("ParsableType") diff --git a/pkgs/node_helpers/node_helpers/parameters/path.py b/pkgs/node_helpers/node_helpers/parameters/path.py new file mode 100644 index 0000000..468bbc7 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/parameters/path.py @@ -0,0 +1,26 @@ +from pathlib import Path + + +def param_path(path_str: str | Path) -> Path: + """Interprets the provided path string as relative to the root of the repository if + it is a relative path. + + :param path_str: The path provided as a ROS parameter + :raises RuntimeError: If the code is placed in some strange location + :return: An absolute path + """ + + path = Path(path_str) + if path.is_absolute(): + return path + + # Find the root of the repository + root_candidate = Path.cwd() + while not (root_candidate / "pkgs").is_dir(): + if len(root_candidate.parents) == 0: + raise RuntimeError( + f"Could not find root of repository! Started at '{Path.cwd()}'." + ) + root_candidate = root_candidate.parent + + return root_candidate / path diff --git a/pkgs/node_helpers/node_helpers/pubsub/README.md b/pkgs/node_helpers/node_helpers/pubsub/README.md new file mode 100644 index 0000000..c1aa111 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/pubsub/README.md @@ -0,0 +1,45 @@ +# node_helpers.pubsub + +This module provides an in-process publish-subscribe (pub-sub) system akin to a signals/slots mechanism. While ROS itself is a pub-sub system for inter-process communication using topics, this module is designed for intra-process communication, enabling lightweight, event-driven interactions between components within the same process. + +This approach is ideal for scenarios where you need responsive communication without the overhead of ROS's networked messaging system. 
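+
+As a minimal sketch of the API (assuming an ``int`` payload; the names here are
+hypothetical), a topic can be subscribed and published to directly. Note that
+subscribers are held via weakref, so the callback must stay referenced somewhere: a
+module-level function or a bound method works, while a bare lambda would be garbage
+collected immediately.
+
+```python3
+from node_helpers.pubsub import Topic
+
+def on_new_value(value: int) -> None:
+    print(f"Received value: {value}")
+
+topic = Topic[int]()
+topic.subscribe(on_new_value)
+topic.publish(42)  # on_new_value(42) is called in the publisher's thread
+```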
+ + +Here's where one might use a topic: + +```python3 +class ExampleSensorBuffer(ABC, Generic[SENSOR_MSG]): + """Some sensor buffer""" + + def __init__(self, ...): + super().__init__() + # Topics + self.on_value_change = Topic[SENSOR_MSG]() + """Called when a reading has changed from a previous value""" + self.on_receive = Topic[SENSOR_MSG]() + """Called whenever a new reading is received""" +``` + +This could then be used by other components like so: + +```python3 + +class ExampleSensor: + """Some sensor""" + + def __init__(self, buffer: ExampleSensorBuffer): + super().__init__() + self.buffer = buffer + self.buffer.on_value_change.subscribe(self.on_value_change) + self.buffer.on_receive.subscribe(self.on_receive) + + def on_value_change(self, msg: SensorMsg): + print(f"Value changed: {msg}") + # Do something with the new value + + def on_receive(self, msg: SensorMsg): + print(f"Received: {msg}") + # Do something with the new value +``` + +The topics could be for any use, not just for sensors. They could be for any kind of event-driven communication within a process. \ No newline at end of file diff --git a/pkgs/node_helpers/node_helpers/pubsub/__init__.py b/pkgs/node_helpers/node_helpers/pubsub/__init__.py new file mode 100644 index 0000000..27c4a8c --- /dev/null +++ b/pkgs/node_helpers/node_helpers/pubsub/__init__.py @@ -0,0 +1,7 @@ +from .event import PublishEvent +from .topic import ( + DuplicateSubscriberError, + SubscriberNotFoundError, + Topic, + multi_subscribe_as_event, +) diff --git a/pkgs/node_helpers/node_helpers/pubsub/event.py b/pkgs/node_helpers/node_helpers/pubsub/event.py new file mode 100644 index 0000000..c1be1fd --- /dev/null +++ b/pkgs/node_helpers/node_helpers/pubsub/event.py @@ -0,0 +1,48 @@ +from threading import Event +from typing import Any + + +class PublishEvent(Event): + """An event with a method that can throw away arguments given to it by a + publisher, and therefore can be used as a subscriber in + ``Topic.subscribe_as_event``. + + Why use this instead of a simple function or lambda that throws away + arguments and calls ``event.set()`` instead? Using this method ensures that + the subscription is garbage collected once this event object falls out of + scope. If we used a weakref to a function or lambda instead, the function + would fall out of scope immediately after calling + ``Topic.subscribe_as_event`` and get garbage collected. If it was a + regular reference to a function or lambda, it would never fall out of scope + and callers would need to remember to unsubscribe the event. + """ + + def set_ignore_args(self, *_args: Any, **_kwargs: Any) -> None: + self.set() + + def wait_and_clear(self, timeout: float | None = None) -> bool: + """This helps reinforce a pattern of 'waiting' then 'clearing' an event. + + Here's an example of waiting on an event with a bug: + >>> while a_value_changed_event.wait(): + >>> if a_value: + >>> break + + In this example, the first time the value changes it will set the event to True, + and now you have a busy loop on your hands. Here's the fixed version: + >>> while a_value_changed_event.wait_and_clear(): + >>> if a_value: + >>> break + + Now, it will wait on the event, clear the event condition, and check the value. + Next time it hits the 'while' line, it will wait again! 
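+
+        As a further sketch, a timeout can bound the wait (reusing the hypothetical
+        event name from above):
+        >>> if not a_value_changed_event.wait_and_clear(timeout=1.0):
+        >>>     print("Timed out waiting for a value change")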
+
+        :param timeout: The timeout in seconds for the event.wait() call
+        :return: True, if the condition was set
+        """
+        was_set = self.wait(timeout=timeout)
+        if was_set:
+            # Only run clear if the event was set, avoiding a race condition where
+            # was_set is False, but it becomes set and is immediately cleared after.
+            self.clear()
+        return was_set
diff --git a/pkgs/node_helpers/node_helpers/pubsub/topic.py b/pkgs/node_helpers/node_helpers/pubsub/topic.py
new file mode 100644
index 0000000..3b4d20c
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/pubsub/topic.py
@@ -0,0 +1,145 @@
+from collections.abc import Callable
+from threading import RLock
+from typing import Generic, TypeVar
+from weakref import ReferenceType, WeakMethod, ref
+
+from .event import PublishEvent
+
+EMITS = TypeVar("EMITS")
+"""The value that gets passed into callbacks"""
+
+SUBSCRIBER = Callable[[EMITS], None]
+"""A function that can be called when a topic has a new value"""
+
+
+class SubscriberNotFoundError(Exception):
+    pass
+
+
+class DuplicateSubscriberError(Exception):
+    pass
+
+
+class Topic(Generic[EMITS]):
+    """A topic that can be subscribed to and published to."""
+
+    def __init__(self) -> None:
+        self._subscribers: list[ReferenceType[SUBSCRIBER[EMITS]]] = []
+        self._subscribers_lock = RLock()
+
+    def subscribe(self, subscriber: SUBSCRIBER[EMITS]) -> None:
+        """Adds a new subscriber to the topic.
+
+        Subscribers will be notified of updates in the same thread as the
+        publisher, so subscription functions should be thread-safe and should
+        avoid doing a lot of work.
+
+        :param subscriber: A function that will be called when the topic is emitted to.
+
+        :raises DuplicateSubscriberError: When subscribing a listener twice to the
+            same resource.
+        """
+        with self._subscribers_lock:
+            if self.is_subscribed(subscriber):
+                message = "A subscriber already exists for this topic!"
+                raise DuplicateSubscriberError(message)
+
+            self._subscribers.append(self._to_weakref(subscriber))
+
+    def subscribe_as_event(self) -> PublishEvent:
+        """Adds a new subscriber and returns an event object that is set when the topic
+        is published to. If the publisher provides any arguments, they will be
+        discarded.
+
+        For example:
+        >>> cool_thing_event = topic.subscribe_as_event()
+        >>> # Now you can wait blockingly until the COOL_THING topic is published to
+        >>> cool_thing_event.wait_and_clear(timeout=3)
+
+        This is useful in situations where the actual value published doesn't matter,
+        but rather, you need to wait until something has been published.
+
+        :return: An event that is set when the topic is published to
+        """
+        event = PublishEvent()
+
+        with self._subscribers_lock:
+            self._subscribers.append(self._to_weakref(event.set_ignore_args))
+
+        return event
+
+    def unsubscribe(self, subscriber: SUBSCRIBER[EMITS]) -> None:
+        """Unsubscribe a listener from this topic
+
+        :param subscriber: The callback to unsubscribe from this topic
+        """
+        with self._subscribers_lock:
+            # Find the weakref that refers to this subscriber
+            subscriber_weakref = self._to_registered_weakref(subscriber)
+            self._subscribers.remove(subscriber_weakref)
+
+    def publish(self, *value: EMITS) -> None:
+        """Publishes a new value for a topic.
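+
+        For example (``some_subscriber`` being any previously subscribed callable):
+        >>> topic = Topic[str]()
+        >>> topic.subscribe(some_subscriber)
+        >>> topic.publish("new value")  # Calls some_subscriber("new value")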
+
+        :param value: The value to pass on to subscribers
+        """
+        with self._subscribers_lock:
+            for callback_weakref in self._subscribers.copy():
+                callback = callback_weakref()
+                if callback is None:
+                    # This listener has been garbage collected, clean it out
+                    self._subscribers.remove(callback_weakref)
+                    continue
+                callback(*value)
+
+    def is_subscribed(self, subscriber: SUBSCRIBER[EMITS]) -> bool:
+        try:
+            self._to_registered_weakref(subscriber)
+        except SubscriberNotFoundError:
+            return False
+        else:
+            return True
+
+    def _to_weakref(
+        self, callback: SUBSCRIBER[EMITS]
+    ) -> ReferenceType[SUBSCRIBER[EMITS]]:
+        """Convert a callable to a weakref, supporting bound and unbound methods"""
+        if hasattr(callback, "__self__") and hasattr(callback, "__func__"):
+            return WeakMethod(callback)
+        else:
+            return ref(callback)
+
+    def _to_registered_weakref(
+        self, subscriber: SUBSCRIBER[EMITS]
+    ) -> ReferenceType[SUBSCRIBER[EMITS]]:
+        """Looks for the weakref for this specifically (already subscribed)
+        subscriber. If it's not found, it raises an exception.
+
+        :param subscriber: The callback to look up
+        :return: The weakref to the subscriber callback
+        :raises SubscriberNotFoundError: When a subscriber doesn't exist for this
+            particular topic.
+        """
+        subscriber_weakref = self._to_weakref(subscriber)
+        try:
+            with self._subscribers_lock:
+                weakref = next(s for s in self._subscribers if s == subscriber_weakref)
+        except StopIteration as ex:
+            raise SubscriberNotFoundError(
+                f"Could not find subscriber {subscriber} "
+                f"in {self.__class__.__name__}"
+            ) from ex
+        return weakref
+
+
+def multi_subscribe_as_event(*topics: Topic[EMITS]) -> PublishEvent:
+    """Subscribes to multiple topics and returns an event that is set when any of
+    the topics are published to.
+
+    :param topics: The topics to subscribe to
+    :return: An event that is set when any of the topics are published to
+    """
+    event = PublishEvent()
+
+    for topic in topics:
+        topic.subscribe(event.set_ignore_args)
+
+    return event
diff --git a/pkgs/node_helpers/node_helpers/qos/__init__.py b/pkgs/node_helpers/node_helpers/qos/__init__.py
new file mode 100644
index 0000000..b35fb6f
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/qos/__init__.py
@@ -0,0 +1,40 @@
+"""This module contains common QOS constants used across the stack.
+
+Make sure to document the intended use case for each QOS, so that they can easily be
+reused.
+"""
+
+from rclpy.qos import (
+    DurabilityPolicy,
+    HistoryPolicy,
+    QoSDurabilityPolicy,
+    QoSProfile,
+    ReliabilityPolicy,
+)
+
+qos_profile_transient_reliable_keep_all = QoSProfile(
+    reliability=ReliabilityPolicy.RELIABLE,
+    durability=DurabilityPolicy.TRANSIENT_LOCAL,
+    history=HistoryPolicy.KEEP_ALL,
+)
+"""This is used for subscribing/publishing to a topic where receiving all messages ever
+published is important.
+
+Keep in mind this should be used for transferring occasional messages, where the number
+of messages will not go up to infinity as time goes on. If it does, you will eventually
+run out of either storage or RAM.
+"""
+
+qos_latching = QoSProfile(
+    depth=1,
+    durability=QoSDurabilityPolicy.TRANSIENT_LOCAL,
+)
+"""An approximation of latching from ROS1. Makes the latest published message available
+always.
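+
+For example, a publisher might use this profile as follows (the node, message type,
+and topic name are hypothetical):
+>>> node.create_publisher(String, "robot_state", qos_profile=qos_latching)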
+""" + +qos_reliable_latest_msg = QoSProfile( + depth=1, + reliability=ReliabilityPolicy.RELIABLE, + history=HistoryPolicy.KEEP_LAST, +) diff --git a/pkgs/node_helpers/node_helpers/robust_rpc/README.md b/pkgs/node_helpers/node_helpers/robust_rpc/README.md new file mode 100644 index 0000000..f7c869d --- /dev/null +++ b/pkgs/node_helpers/node_helpers/robust_rpc/README.md @@ -0,0 +1,9 @@ +# node_helpers.robust_rpc + +The RobustRPC framework is one of the key components of ``node_helpers``. Its key API +is the ``RobustRPCMixin``, which provides a robust approach to handling errors in service +and action calls by propagating error messages raised by the server and re-raising them +on the client side. +This documentation aims to help users understand and effectively use the RobustRPCMixin. + +The Robust RPC framework is documented in [docs/](../../../../docs/robust_rpc.rst) \ No newline at end of file diff --git a/pkgs/node_helpers/node_helpers/robust_rpc/__init__.py b/pkgs/node_helpers/node_helpers/robust_rpc/__init__.py new file mode 100644 index 0000000..95f8c07 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/robust_rpc/__init__.py @@ -0,0 +1,22 @@ +from .action_client import RobustActionClient +from .errors import ( + ActionCancellationRejected, + ExecutorNotSetError, + InvalidRobustMessage, + RobustRPCException, +) +from .mixin import RobustRPCMixin +from .schema import ERROR_DESCRIPTION_FIELD, ERROR_NAME_FIELD +from .service_client import RobustServiceClient + +__all__ = [ + "RobustActionClient", + "ActionCancellationRejected", + "ExecutorNotSetError", + "InvalidRobustMessage", + "RobustRPCException", + "RobustRPCMixin", + "ERROR_DESCRIPTION_FIELD", + "ERROR_NAME_FIELD", + "RobustServiceClient", +] diff --git a/pkgs/node_helpers/node_helpers/robust_rpc/_readiness.py b/pkgs/node_helpers/node_helpers/robust_rpc/_readiness.py new file mode 100644 index 0000000..e878087 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/robust_rpc/_readiness.py @@ -0,0 +1,52 @@ +from collections.abc import Callable +from time import time + +from rclpy.node import Node + +from .errors import ExecutorNotSetError + + +class ValidatesReadinessMixin: + _node: Node + """A reference to the node that handles this system""" + + _has_been_called: bool = False + """This returns true once the server has been called and validated to be ready""" + + def _validate_rpc_server_is_ready( + self, wait_fn: Callable[[], bool], rpc_name: str + ) -> None: + """This function solves the problem of making sure an action server is ready + and callable by a client before calling it. + + There are two situations where a server might not be callable: + 1) The server is not 'ready'. In these cases, the call will hang forever. + 2) The client side doesn't have an executor set. E.g., rclpy.spin() is not + yet running, which is common when in a Node.__init__(). In this case, an RPC + call will also hang forever. + + For case 1, this function will block indefinitely and periodically print logs. + For case 2, this function will raise an ExecutorNotSet exception with an + explanation. + + :param wait_fn: A function that will wait some time return true/false if + the service or action is ready. + :param rpc_name: The name of the service or action being used. + :raises ExecutorNotSetError: When called in a nodes __init__. + """ + + if not self._has_been_called: + if self._node.executor is None: + raise ExecutorNotSetError( + f"The RPC '{self}' has been called before a node was assigned an " + f"executor! 
Did you try running an RPC call in the init of a node? " + f"Try using a single shot timer instead!" + ) + + start_time = time() + while not wait_fn(): + elapsed = time() - start_time + self._node.get_logger().error( + f"RPC server '{rpc_name}' is not ready yet after {elapsed} seconds!" + ) + self._has_been_called = True diff --git a/pkgs/node_helpers/node_helpers/robust_rpc/_wrappers.py b/pkgs/node_helpers/node_helpers/robust_rpc/_wrappers.py new file mode 100644 index 0000000..c4d9f1e --- /dev/null +++ b/pkgs/node_helpers/node_helpers/robust_rpc/_wrappers.py @@ -0,0 +1,109 @@ +import logging +import traceback +from collections.abc import Callable +from typing import Any + +from rclpy import Future +from rclpy.action.server import ServerGoalHandle + +from .errors import RobustRPCException +from .schema import ERROR_DESCRIPTION_FIELD, ERROR_NAME_FIELD +from .typing import RequestType, ResponseType + +ServiceCallback = Callable[[RequestType, ResponseType], ResponseType] +ActionCallback = Callable[[ServerGoalHandle], ResponseType] + + +def patch_future_to_raise_exception(future: Future, parse_as_action: bool) -> None: + """Patches the future.result() method to raise an error if the result from the + service or action has an error set within the resulting message. + :param future: The future to patch the result() method of + :param parse_as_action: If True, it will expect the exception information to be + nested within the "result" attribute, as actions do. Otherwise, it will expect + the exception name and description to be root object attributes. + """ + + def result_patch() -> Any: + if future._exception: # noqa: SLF001 + # If the future failed for other reasons, prioritize that exception + raise future.exception() + + result = future._result # noqa: SLF001 + msg_with_error_info = result.result if parse_as_action else result + msg_error = getattr(msg_with_error_info, ERROR_NAME_FIELD) + + if msg_error != "": + # Since an error occurred, raise the appropriate error type + msg_error_description = getattr( + msg_with_error_info, ERROR_DESCRIPTION_FIELD + ) + error_class = RobustRPCException.like(msg_error) + raise error_class( + error_name=msg_error, + error_description=msg_error_description, + message=result, + ) + + return result + + future.result = result_patch + + +def wrap_service_callback(callback: ServiceCallback, srv_name: str) -> ServiceCallback: + """Wraps a service callback, catches errors, and puts them in service message""" + + def wrapper(request: RequestType, response: ResponseType) -> ResponseType: + try: + return callback(request, response) + except Exception as ex: # noqa: BLE001 + _add_error_info(ex, srv_name, response) + return response + + return wrapper + + +def wrap_action_callback( + callback: ActionCallback, action_name: str, result_type: type[ResponseType] +) -> ActionCallback: + """Wraps an action callback, catches errors, and puts them in the action message.""" + + def wrapper(goal: ServerGoalHandle) -> ResponseType: + try: + return callback(goal) + except Exception as ex: # noqa: BLE001 + goal.abort() + response = result_type() + _add_error_info(ex, action_name, response) + return response + + return wrapper + + +def _add_error_info(ex: Exception, rpc_name: str, result: ResponseType) -> None: + """Adds error information to the provided result if the result type has the + necessary fields + + :param ex: The exception to get information on + :param rpc_name: The name of the service or action that has failed + :param result: The result to provide the information to + 
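+    For example, if a callback raised ValueError("bad input"), the result would end up
+    with error_name="ValueError" and error_description="bad input".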
""" + if isinstance(ex, RobustRPCException): + # Pass through the pre-existing name and description + name = ex.error_name + description = ex.error_description + else: + name = type(ex).__name__ + description = str(ex) + + logging.error( + f"Exception while running RPC '{rpc_name}':\n{traceback.format_exc()}" + ) + setattr(result, ERROR_NAME_FIELD, name) + setattr(result, ERROR_DESCRIPTION_FIELD, description) + + +__all__ = [ + "wrap_action_callback", + "wrap_service_callback", + "patch_future_to_raise_exception", +] diff --git a/pkgs/node_helpers/node_helpers/robust_rpc/action_client.py b/pkgs/node_helpers/node_helpers/robust_rpc/action_client.py new file mode 100644 index 0000000..c8af976 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/robust_rpc/action_client.py @@ -0,0 +1,146 @@ +import contextlib +from collections.abc import Generator +from threading import RLock +from typing import Any, cast + +from action_msgs.msg import GoalStatus +from action_msgs.srv import CancelGoal +from rclpy import Future +from rclpy.action import ActionClient +from rclpy.action.client import ClientGoalHandle + +from ..futures import wait_for_future +from ._readiness import ValidatesReadinessMixin +from ._wrappers import patch_future_to_raise_exception +from .errors import ActionCancellationRejected +from .typing import RobustActionMsg + + +class RobustActionClient(ActionClient, ValidatesReadinessMixin): + """Wraps an ActionClient so that any exceptions raised by the remote server are + raised in the client, when calling send_goal_async().result() or send_goal() + """ + + @contextlib.contextmanager + def send_goal_as_context( + self, + goal: RobustActionMsg, + timeout: float | None = None, + cancel: bool = True, + block_for_result: bool = True, + **kwargs: Any, + ) -> Generator[ClientGoalHandle, None, None]: + """Create a context manager that will send a goal, then after the context cancel + and wait for a result. + + Note: There are no guarantees that by the time the context is entered, that the + action is running yet. + + :param goal: The goal to send + :param timeout: The timeout for the goal to expected, and for the goal to be + cancelled. Helpful for when writing tests. + :param cancel: If True, the action will be cancelled if its not already finished + when the context ends. If the cancellation is rejected, an exception will + be raised. + :param block_for_result: If True, the context manager will block for a result + when exiting + :param kwargs: Any other send_goal_async parameters + :yields: The ClientGoalHandle of the created action + + :raises ActionCancellationRejected: If cancel=True but the server rejects the + cancellation request. + """ + goal_handle = wait_for_future( + self.send_goal_async(goal=goal, **kwargs), ClientGoalHandle, timeout=timeout + ) + + try: + yield goal_handle + finally: + # Cancel the goal if it hasn't finished yet + if cancel and goal_handle.status in [ + GoalStatus.STATUS_UNKNOWN, + GoalStatus.STATUS_ACCEPTED, + GoalStatus.STATUS_EXECUTING, + ]: + cancel_response = wait_for_future( + goal_handle.cancel_goal_async(), + type_=CancelGoal.Response, + timeout=timeout, + ) + + action_aborted = ( + cancel_response.return_code + == CancelGoal.Response.ERROR_GOAL_TERMINATED + ) + + # If the goal was not cancelled for reasons other than a remote error, + # raise an exception. 
+ if len(cancel_response.goals_canceling) == 0 and not action_aborted: + raise ActionCancellationRejected( + f"Action {self._action_name} is required to be cancellable, but" + f" the cancellation was rejected! " + f"Cancellation Code: {cancel_response.return_code}" + ) + + if block_for_result: + # Wait for the goal to finish cancelling + wait_for_future(goal_handle.get_result_async(), object, timeout=timeout) + + def send_goal_async(self, *args: Any, **kwargs: Any) -> Future: + self._validate_rpc_server_is_ready( + wait_fn=lambda: cast(bool, self.wait_for_server(10)), + rpc_name=self._action_name, + ) + + return super().send_goal_async(*args, **kwargs) + + def _get_result_async(self, goal_handle: ClientGoalHandle) -> Future: + """Patch the future so that when result() is called it raises remote errors""" + + future = super()._get_result_async(goal_handle) + patch_future_to_raise_exception(future=future, parse_as_action=True) + return future + + +class PatchRclpyIssue1123(RobustActionClient): + """TODO: Remove this patch after the issue https://github.com/ros2/rclpy/issues/1123 + is resolved. This might be done on ROS2 Jazzy. + + Tracking PR: https://github.com/ros2/rclpy/pull/1308 + + The short summary here: This hacky patch fixes an issue where state within the + ActionClient.*_async() functions had a race condition with what ActionClient.execute + does, resulting in Actions that would never complete. + + Please remove this hack once this issue is resolved upstream. + """ + + _lock: RLock = None # type: ignore + + @property + def _cpp_client_handle_lock(self) -> RLock: + if self._lock is None: + self._lock = RLock() + return self._lock + + async def execute(self, *args: Any, **kwargs: Any) -> None: + # This is ugly- holding on to a lock in an async environment feels gross + with self._cpp_client_handle_lock: + return await super().execute(*args, **kwargs) # type: ignore + + def send_goal_async(self, *args: Any, **kwargs: Any) -> Future: + with self._cpp_client_handle_lock: + return super().send_goal_async(*args, **kwargs) + + def _cancel_goal_async(self, *args: Any, **kwargs: Any) -> Future: + with self._cpp_client_handle_lock: + return super()._cancel_goal_async(*args, **kwargs) + + def _get_result_async(self, *args: Any, **kwargs: Any) -> Future: + with self._cpp_client_handle_lock: + return super()._get_result_async(*args, **kwargs) + + +# Replace the robust action client with its patched variant +RobustActionClient = PatchRclpyIssue1123 # type: ignore diff --git a/pkgs/node_helpers/node_helpers/robust_rpc/action_server.py b/pkgs/node_helpers/node_helpers/robust_rpc/action_server.py new file mode 100644 index 0000000..1b84f7f --- /dev/null +++ b/pkgs/node_helpers/node_helpers/robust_rpc/action_server.py @@ -0,0 +1,34 @@ +import logging +from typing import Any + +from rclpy.action import ActionServer + + +class _PatchedRclpyIssue1236(ActionServer): + """ + TODO: Remove this class once we upgrade from Humble to Jazzy, assuming it's fixed + Tracking issue: https://github.com/ros2/rclpy/issues/1236 + """ + + async def _execute_goal(self, execute_callback: Any, goal_handle: Any) -> Any: + try: + return await super()._execute_goal(execute_callback, goal_handle) + except KeyError: + logging.error( + "Caught KeyError in action server, ignoring it and returning." 
+                f" Happened in function=_execute_goal, "
+                f"with args {execute_callback=} {goal_handle=}"
+            )
+
+    async def _execute_get_result_request(self, request_header_and_message: Any) -> Any:
+        try:
+            return await super()._execute_get_result_request(request_header_and_message)
+        except KeyError:
+            logging.error(
+                "Caught KeyError in action server, ignoring it and returning."
+                f" Happened in function=_execute_get_result_request, "
+                f"with args {request_header_and_message=}"
+            )
+
+
+RobustActionServer = _PatchedRclpyIssue1236
diff --git a/pkgs/node_helpers/node_helpers/robust_rpc/errors.py b/pkgs/node_helpers/node_helpers/robust_rpc/errors.py
new file mode 100644
index 0000000..8ef9426
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/robust_rpc/errors.py
@@ -0,0 +1,81 @@
+from threading import RLock
+from typing import Any
+
+
+class InvalidRobustMessage(Exception):
+    """This exception is raised if a message lacks the required fields for propagating
+    error information.
+    """
+
+
+class ActionCancellationRejected(Exception):
+    """Raised if it was implied an action needed to be cancelled, but the cancellation
+    request was rejected.
+    """
+
+
+class ExecutorNotSetError(Exception):
+    """Raised if a service or action attempts to send a request before the node's
+    executor has been set. This can happen when running an action in the __init__ of a
+    node."""
+
+
+_ROBUST_RPC_EXCEPTION_CACHE_LOCK = RLock()
+_ROBUST_RPC_EXCEPTION_CACHE: dict[str, type["RobustRPCException"]] = {}
+"""This dict carries a mapping of 'error_name' to a dynamically generated exception
+object of the same name."""
+
+
+class RobustRPCException(Exception):
+    def __init__(self, error_name: str, error_description: str, message: Any):
+        super().__init__(
+            f"RPC Call failed with exception {error_name}('{error_description}')"
+        )
+        self.error_name = error_name
+        self.error_description = error_description
+        self.message = message
+
+    @classmethod
+    def like(
+        cls: type["RobustRPCException"], exception: str | type[Exception]
+    ) -> "type[RobustRPCException]":
+        """The most ergonomic way to catch errors that occurred remotely!
+
+        This function returns an error 'like' the one passed in, but that subclasses
+        RobustRPCException, and will match the same error object that would be raised
+        for an error of the same name.
+
+        :param exception: The object or name you want to create a mirroring error
+            object of.
+        :return: An object with the same name as 'exception', but that subclasses
+            RobustRPCException. 
+
+        Usage example:
+        >>> try:
+        >>>     maybe_raises_an_rpc_error()
+        >>> except RobustRPCException.like(OutOfBoundsError):
+        >>>     # In this case, the 'like' function will take the .__class__.__name__
+        >>>     # of the error to create a new one that subclasses RobustRPCException
+        >>>     handle_out_of_bounds_case()
+        >>> except RobustRPCException.like("ErrorFromAPackageThatsNotADependency"):
+        >>>     # In this case, the string is used as the name of the new error object
+        >>>     handle_error()
+        >>> except RobustRPCException:
+        >>>     # All errors still subclass RobustRPCException and can be caught
+        >>>     handle_all_other_robust_rpc_errors()
+
+        Some useful properties to explain the output error object:
+        >>> assert RobustRPCException.like(OutOfBoundsError) != OutOfBoundsError
+        >>> assert RobustRPCException.like(OutOfBoundsError) is not OutOfBoundsError
+        >>> assert (
+        >>>     RobustRPCException.like(OutOfBoundsError).__name__
+        >>>     == OutOfBoundsError.__name__
+        >>> )
+        """
+        ex_name: str = exception if isinstance(exception, str) else exception.__name__
+
+        with _ROBUST_RPC_EXCEPTION_CACHE_LOCK:
+            if ex_name not in _ROBUST_RPC_EXCEPTION_CACHE:
+                _ROBUST_RPC_EXCEPTION_CACHE[ex_name] = type(ex_name, (cls,), {})
+
+            return _ROBUST_RPC_EXCEPTION_CACHE[ex_name]
diff --git a/pkgs/node_helpers/node_helpers/robust_rpc/mixin.py b/pkgs/node_helpers/node_helpers/robust_rpc/mixin.py
new file mode 100644
index 0000000..bf69a8e
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/robust_rpc/mixin.py
@@ -0,0 +1,88 @@
+from typing import Any
+
+from rclpy.action import ActionServer
+from rclpy.node import Node
+from rclpy.service import Service
+
+from ._wrappers import (
+    ActionCallback,
+    ServiceCallback,
+    wrap_action_callback,
+    wrap_service_callback,
+)
+from .action_client import RobustActionClient
+from .action_server import RobustActionServer
+from .schema import validate_robust_message
+from .service_client import RobustServiceClient
+from .typing import RobustActionMsg, RobustServiceMsg
+
+
+class RobustRPCMixin:
+    """This mixin adds methods for creating service/action clients and servers.
+
+    When the robust client and server are used on both sides, error messages raised by
+    the server will be caught, passed in the response message, and then re-raised on
+    the client side.
+
+    The client will raise an error of the same name that subclasses RobustRPCException,
+    which has the properties error_name, error_description, and message. The message
+    holds the message object that caused the exception to be raised. 
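+
+    A minimal usage sketch (CoolService is a hypothetical srv type whose response has
+    the robust error fields):
+    >>> class MyNode(RobustRPCMixin, Node):
+    >>>     def __init__(self) -> None:
+    >>>         super().__init__("my_node")
+    >>>         self.client = self.create_robust_client(CoolService, "cool_service")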
+ """ + + def create_robust_client( + self: Node, srv_type: type[RobustServiceMsg], srv_name: str, **kwargs: Any + ) -> RobustServiceClient: + validate_robust_message(srv_type.Response) + client = self.create_client(srv_type=srv_type, srv_name=srv_name, **kwargs) + return RobustServiceClient.from_client(client, node=self) + + def create_robust_service( + self: Node, + srv_type: type[RobustServiceMsg], + srv_name: str, + callback: ServiceCallback, + **kwargs: Any, + ) -> Service: + validate_robust_message(srv_type.Response) + return self.create_service( + srv_type=srv_type, + srv_name=srv_name, + callback=wrap_service_callback(callback=callback, srv_name=srv_name), + **kwargs, + ) + + def create_robust_action_client( + self: Node, + action_type: type[RobustActionMsg], + action_name: str, + **kwargs: Any, + ) -> RobustActionClient: + validate_robust_message(action_type.Result) + return RobustActionClient(self, action_type, action_name, **kwargs) + + def create_robust_action_server( + self: Node, + action_type: type[RobustActionMsg], + action_name: str, + execute_callback: ActionCallback, + result_timeout: float | None = None, + **kwargs: Any, + ) -> ActionServer: + validate_robust_message(action_type.Result) + + return RobustActionServer( + self, + action_type=action_type, + action_name=action_name, + execute_callback=wrap_action_callback( + callback=execute_callback, + action_name=action_name, + result_type=action_type.Result, + ), + # By default, significantly reduce the result_timeout as compared to the + # default rclpy behavior. This is done because the default behavior keeps a + # cache of the last FIFTEEN MINUTES (!) worth of action call results, which + # can lead to significant (60-80%) slowdowns in action call times. + result_timeout=result_timeout or 10, + **kwargs, + ) diff --git a/pkgs/node_helpers/node_helpers/robust_rpc/schema.py b/pkgs/node_helpers/node_helpers/robust_rpc/schema.py new file mode 100644 index 0000000..5c6a1fb --- /dev/null +++ b/pkgs/node_helpers/node_helpers/robust_rpc/schema.py @@ -0,0 +1,22 @@ +from .errors import InvalidRobustMessage +from .typing import ResponseType + +ERROR_NAME_FIELD = "error_name" +ERROR_DESCRIPTION_FIELD = "error_description" + + +def validate_robust_message(message: type[ResponseType]) -> None: + """Validates that a message can be used in the robust RPC framework + + This necessitates having the error name and error description fields. + + :param message: Either a Service or Action type message to check the fields for + :raises InvalidRobustMessage: If any required fields are missing. 
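+
+    For example, a compatible service definition might look like (hypothetical):
+
+        # CoolService.srv
+        string some_request_field
+        ---
+        string error_name
+        string error_description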
+ """ + has_name_field = hasattr(message, ERROR_NAME_FIELD) + has_description_field = hasattr(message, ERROR_DESCRIPTION_FIELD) + if not has_name_field or not has_description_field: + raise InvalidRobustMessage( + f"The message {message} is missing required fields: " + f"{ERROR_NAME_FIELD}, {ERROR_DESCRIPTION_FIELD}" + ) diff --git a/pkgs/node_helpers/node_helpers/robust_rpc/service_client.py b/pkgs/node_helpers/node_helpers/robust_rpc/service_client.py new file mode 100644 index 0000000..05f7485 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/robust_rpc/service_client.py @@ -0,0 +1,47 @@ +from typing import cast + +from rclpy import Future +from rclpy.client import Client +from rclpy.node import Node, SrvTypeRequest + +from ._readiness import ValidatesReadinessMixin +from ._wrappers import patch_future_to_raise_exception + + +class RobustServiceClient(Client, ValidatesReadinessMixin): + """Wraps an rclpy.Client so that any exceptions raised by the remote server are + raised in the client, when calling call_async().result() or call() + """ + + @classmethod + def from_client(cls, client: Client, node: Node) -> "RobustServiceClient": + """Wrap an instantiated client object in-place as a RobustServiceClient. + + Unlike the RobustActionClient, it was impossible to subclass the Client and + patch rclpy to automatically create a RobustServiceClient. In place of that, the + creation of the RobustServiceClient is done by monkeypatching the `__class__` + attribute, in the from_client method. + + :param client: The client to wrap with the RobustServiceClient interface + :param node: The node that is creating this client. Currently this is only used + by the ValidatesReadinessMixin to check if there is an executor present. + :returns: The client, but as a RobustServiceClient + """ + client.__class__ = RobustServiceClient + + # Patch in the 'node' attribute to fulfill the ValidatesReadinessMixin needs + client._node = node # noqa: SLF001 + + return client # type: ignore + + def call_async(self, request: SrvTypeRequest) -> Future: + """Patch the future so that when result() is called it raises remote errors""" + self._validate_rpc_server_is_ready( + wait_fn=lambda: cast(bool, self.wait_for_service(10)), + rpc_name=self.srv_name, + ) + + future = super().call_async(request) + + patch_future_to_raise_exception(future=future, parse_as_action=False) + return future diff --git a/pkgs/node_helpers/node_helpers/robust_rpc/typing.py b/pkgs/node_helpers/node_helpers/robust_rpc/typing.py new file mode 100644 index 0000000..d6b6c6e --- /dev/null +++ b/pkgs/node_helpers/node_helpers/robust_rpc/typing.py @@ -0,0 +1,22 @@ +from typing import Protocol + + +class RequestType(Protocol): + """For type hinting a service_msg.Request or an action_msg.Goal""" + + +class ResponseType(Protocol): + """For type hinting a service_msg.Response or an action_msg.Result""" + + error_name: str + error_description: str + + +class RobustServiceMsg(Protocol): + Request: type[RequestType] + Response: type[ResponseType] + + +class RobustActionMsg(Protocol): + Goal: type[RequestType] + Result: type[ResponseType] diff --git a/pkgs/node_helpers/node_helpers/ros2_numpy/LICENSE b/pkgs/node_helpers/node_helpers/ros2_numpy/LICENSE new file mode 100644 index 0000000..cdec63e --- /dev/null +++ b/pkgs/node_helpers/node_helpers/ros2_numpy/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2016 Eric Wieser + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the 
"Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/pkgs/node_helpers/node_helpers/ros2_numpy/README.md b/pkgs/node_helpers/node_helpers/ros2_numpy/README.md new file mode 100644 index 0000000..2da527b --- /dev/null +++ b/pkgs/node_helpers/node_helpers/ros2_numpy/README.md @@ -0,0 +1,105 @@ +# Urban Machine Note + +This is a hard fork of [ros2_numpy][ros2_numpy] maintained inside the `node_helpers` +repository. At the time of writing, this fork was made at upstream commit ID +``780ef66b7e800bf278c6f72b82f821cc542eed7d``. It was last updated to have feature parity +at 2024-11-22. + +It is considered to be essential for use in nearly any python-based robotics stack, so +it is maintained here to ensure that it is available to all robot packages. + +We also will contribute upstream changes to the original repository as needed. However, +because the original ros2_numpy repository is not available on the ROS2 packages index, +we decided to maintain this fork inside of node_helpers. + +The original README is included below. + +--- + +# Original README + +[ros2_numpy]: https://github.com/Box-Robotics/ros2_numpy + +This project is a fork of [ros_numpy](https://github.com/eric-wieser/ros_numpy) +to work with ROS 2. It provides tools for converting ROS messages to and from +numpy arrays. In the ROS 2 port, the module has been renamed to +`ros2_numpy`. Users are encouraged to update their application code to import +the module as shown below. + +ROS 2: + +``` +import ros2_numpy as rnp +``` + +ROS 1: + +``` +import ros_numpy as rnp +``` + +Prefacing your calls like `rnp.numpify(...)` or `rnp.msgify(...)` should help +future proof your codebase while the ROS 1 and ROS 2 ports are API compatible. + +The ROS 2 port has been bootstrapped as version `2.0.3`. The `MAJOR` +version has been set to `2` to indicate ROS 2 and the `MINOR` and `PATCH` +versions match the ROS 1 version from which the ROS 2 port was +bootstrapped. The reasoning behind this is to allow for creating tags in this +fork that can be released into the ROS 2 distribution while not conflicting +with existing tags on the upstream repository. A release into Foxy is still +pending. 
+
+This module contains two core functions:
+
+* `arr = numpify(msg, ...)` - try to get a numpy object from a message
+* `msg = msgify(MessageType, arr, ...)` - try to convert a numpy object to a message
+
+Currently supports:
+
+* `sensor_msgs.msg.PointCloud2` ↔ structured `np.array`:
+
+  ```python
+  data = np.zeros(100, dtype=[
+    ('x', np.float32),
+    ('y', np.float32),
+    ('vectors', np.float32, (3,))
+  ])
+  data['x'] = np.arange(100)
+  data['y'] = data['x']*2
+  data['vectors'] = np.arange(100)[:,np.newaxis]
+
+  msg = ros2_numpy.msgify(PointCloud2, data)
+  ```
+
+  ```python
+  data = ros2_numpy.numpify(msg)
+  ```
+
+* `sensor_msgs.msg.Image` ↔ 2/3-D `np.array`, similar to the function of `cv_bridge`, but without the dependency on `cv2`
+* `nav_msgs.msg.OccupancyGrid` ↔ `np.ma.array`
+* `geometry_msgs.msg.Vector3` ↔ 1-D `np.array`. `hom=True` gives `[x, y, z, 0]`
+* `geometry_msgs.msg.Point` ↔ 1-D `np.array`. `hom=True` gives `[x, y, z, 1]`
+* `geometry_msgs.msg.Quaternion` ↔ 1-D `np.array`, `[x, y, z, w]`
+* `geometry_msgs.msg.Transform` ↔ 4×4 `np.array`, the homogeneous transformation matrix
+* `geometry_msgs.msg.Pose` ↔ 4×4 `np.array`, the homogeneous transformation matrix from the origin
+
+Support for more types can be added with:
+
+```python
+@ros2_numpy.converts_to_numpy(SomeMessageClass)
+def convert(my_msg):
+    return np.array(...)
+
+@ros2_numpy.converts_from_numpy(SomeMessageClass)
+def convert(my_array):
+    return SomeMessageClass(...)
+```
+
+Any extra args or kwargs to `numpify` or `msgify` will be forwarded to your
+conversion function.
+
+
+## Future work
+
+* Add simple conversions for:
+
+  * `geometry_msgs.msg.Inertia`
diff --git a/pkgs/node_helpers/node_helpers/ros2_numpy/__init__.py b/pkgs/node_helpers/node_helpers/ros2_numpy/__init__.py
new file mode 100644
index 0000000..25e9744
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/ros2_numpy/__init__.py
@@ -0,0 +1,12 @@
+"""
+A module for converting ROS message types into numpy types, where appropriate
+"""
+
+# isort: skip_file
+
+from .registry import numpify, msgify
+from . import point_cloud2
+from . import image
+from . import occupancy_grid
+from . import geometry
+from . import laser_scan
diff --git a/pkgs/node_helpers/node_helpers/ros2_numpy/geometry.py b/pkgs/node_helpers/node_helpers/ros2_numpy/geometry.py
new file mode 100644
index 0000000..dacd576
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/ros2_numpy/geometry.py
@@ -0,0 +1,173 @@
+import numpy as np
+import numpy.typing as npt
+import tf_transformations as transformations
+from geometry_msgs.msg import Point, Pose, Quaternion, Transform, Vector3
+
+from . 
import numpify
+from .registry import converts_from_numpy, converts_to_numpy
+
+# basic types
+
+
+@converts_to_numpy(Vector3)
+def vector3_to_numpy(msg: Vector3, hom: bool = False) -> npt.NDArray[np.float64]:
+    if hom:
+        return np.array([msg.x, msg.y, msg.z, 0])
+    else:
+        return np.array([msg.x, msg.y, msg.z])
+
+
+@converts_from_numpy(Vector3)
+def numpy_to_vector3(arr: npt.NDArray[np.float64]) -> Vector3:
+    if arr.shape[-1] == 4:
+        assert np.all(arr[..., -1] == 0)
+        arr = arr[..., :-1]
+
+    if len(arr.shape) == 1:
+        return Vector3(**dict(zip(["x", "y", "z"], arr, strict=False)))
+    else:
+        return np.apply_along_axis(
+            lambda v: Vector3(**dict(zip(["x", "y", "z"], v, strict=False))),
+            axis=-1,
+            arr=arr,
+        )
+
+
+@converts_to_numpy(Point)
+def point_to_numpy(msg: Point, hom: bool = False) -> npt.NDArray[np.float64]:
+    if hom:
+        return np.array([msg.x, msg.y, msg.z, 1])
+    else:
+        return np.array([msg.x, msg.y, msg.z])
+
+
+@converts_from_numpy(Point)
+def numpy_to_point(arr: npt.NDArray[np.float64]) -> Point:
+    if arr.shape[-1] == 4:
+        arr = arr[..., :-1] / arr[..., -1]
+
+    if len(arr.shape) == 1:
+        return Point(**dict(zip(["x", "y", "z"], arr, strict=True)))
+    else:
+        return np.apply_along_axis(
+            lambda v: Point(**dict(zip(["x", "y", "z"], v, strict=True))),
+            axis=-1,
+            arr=arr,
+        )
+
+
+@converts_to_numpy(Quaternion)
+def quat_to_numpy(msg: Quaternion) -> npt.NDArray[np.float64]:
+    return np.array([msg.x, msg.y, msg.z, msg.w])
+
+
+@converts_from_numpy(Quaternion)
+def numpy_to_quat(arr: npt.NDArray[np.float64]) -> Quaternion:
+    assert arr.shape[-1] == 4
+
+    if len(arr.shape) == 1:
+        return Quaternion(**dict(zip(["x", "y", "z", "w"], arr, strict=False)))
+    else:
+        return np.apply_along_axis(
+            lambda v: Quaternion(**dict(zip(["x", "y", "z", "w"], v, strict=False))),
+            axis=-1,
+            arr=arr,
+        )
+
+
+# compound types
+# all of these take ...x4x4 homogeneous matrices
+
+
+@converts_to_numpy(Transform)
+def transform_to_numpy(msg: Transform) -> npt.NDArray[np.float64]:
+    return np.dot(
+        transformations.translation_matrix(numpify(msg.translation)),
+        transformations.quaternion_matrix(numpify(msg.rotation)),
+    )
+
+
+@converts_from_numpy(Transform)
+def numpy_to_transform(arr: npt.NDArray[np.float64]) -> Transform:
+    shape, rest = arr.shape[:-2], arr.shape[-2:]
+    assert rest == (4, 4)
+
+    if len(shape) == 0:
+        trans = transformations.translation_from_matrix(arr)
+        quat = transformations.quaternion_from_matrix(arr)
+
+        return Transform(
+            translation=Vector3(**dict(zip(["x", "y", "z"], trans, strict=True))),
+            rotation=Quaternion(**dict(zip(["x", "y", "z", "w"], quat, strict=True))),
+        )
+    else:
+        res = np.empty(shape, dtype=np.object_)
+        for idx in np.ndindex(shape):
+            res[idx] = Transform(
+                translation=Vector3(
+                    **dict(
+                        zip(
+                            ["x", "y", "z"],
+                            transformations.translation_from_matrix(arr[idx]),
+                            strict=True,
+                        )
+                    )
+                ),
+                rotation=Quaternion(
+                    **dict(
+                        zip(
+                            ["x", "y", "z", "w"],
+                            transformations.quaternion_from_matrix(arr[idx]),
+                            strict=True,
+                        )
+                    )
+                ),
+            )
+        return res
+
+
+@converts_to_numpy(Pose)
+def pose_to_numpy(msg: Pose) -> npt.NDArray[np.float64]:
+    return np.dot(
+        transformations.translation_matrix(numpify(msg.position)),
+        transformations.quaternion_matrix(numpify(msg.orientation)),
+    )
+
+
+@converts_from_numpy(Pose)
+def numpy_to_pose(arr: npt.NDArray[np.float64]) -> Pose:
+    shape, rest = arr.shape[:-2], arr.shape[-2:]
+    assert rest == (4, 4)
+
+    if len(shape) == 0:
+        trans = transformations.translation_from_matrix(arr)
+        quat = transformations.quaternion_from_matrix(arr)
+
+        return Pose(
+            position=Point(**dict(zip(["x", "y", "z"], trans, strict=False))),
+            orientation=Quaternion(
+                **dict(zip(["x", "y", "z", "w"], quat, strict=False))
+            ),
+        )
+    else:
+        res = np.empty(shape, dtype=np.object_)
+        for idx in np.ndindex(shape):
+            res[idx] = Pose(
+                position=Point(
+                    **dict(
+                        zip(
+                            ["x", "y", "z"],
+                            transformations.translation_from_matrix(arr[idx]),
+                            strict=True,
+                        )
+                    )
+                ),
+                orientation=Quaternion(
+                    **dict(
+                        zip(
+                            ["x", "y", "z", "w"],
+                            transformations.quaternion_from_matrix(arr[idx]),
+                            strict=True,
+                        )
+                    )
+                ),
+            )
+        return res
diff --git a/pkgs/node_helpers/node_helpers/ros2_numpy/image.py b/pkgs/node_helpers/node_helpers/ros2_numpy/image.py
new file mode 100644
index 0000000..1988a00
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/ros2_numpy/image.py
@@ -0,0 +1,113 @@
+import sys
+
+import numpy as np
+import numpy.typing as npt
+from sensor_msgs.msg import Image
+
+from .registry import converts_from_numpy, converts_to_numpy
+
+name_to_dtypes = {
+    "rgb8": (np.uint8, 3),
+    "rgba8": (np.uint8, 4),
+    "rgb16": (np.uint16, 3),
+    "rgba16": (np.uint16, 4),
+    "bgr8": (np.uint8, 3),
+    "bgra8": (np.uint8, 4),
+    "bgr16": (np.uint16, 3),
+    "bgra16": (np.uint16, 4),
+    "mono8": (np.uint8, 1),
+    "mono16": (np.uint16, 1),
+    # for bayer image (based on cv_bridge.cpp)
+    "bayer_rggb8": (np.uint8, 1),
+    "bayer_bggr8": (np.uint8, 1),
+    "bayer_gbrg8": (np.uint8, 1),
+    "bayer_grbg8": (np.uint8, 1),
+    "bayer_rggb16": (np.uint16, 1),
+    "bayer_bggr16": (np.uint16, 1),
+    "bayer_gbrg16": (np.uint16, 1),
+    "bayer_grbg16": (np.uint16, 1),
+    # OpenCV CvMat types
+    "8UC1": (np.uint8, 1),
+    "8UC2": (np.uint8, 2),
+    "8UC3": (np.uint8, 3),
+    "8UC4": (np.uint8, 4),
+    "8SC1": (np.int8, 1),
+    "8SC2": (np.int8, 2),
+    "8SC3": (np.int8, 3),
+    "8SC4": (np.int8, 4),
+    "16UC1": (np.uint16, 1),
+    "16UC2": (np.uint16, 2),
+    "16UC3": (np.uint16, 3),
+    "16UC4": (np.uint16, 4),
+    "16SC1": (np.int16, 1),
+    "16SC2": (np.int16, 2),
+    "16SC3": (np.int16, 3),
+    "16SC4": (np.int16, 4),
+    "32SC1": (np.int32, 1),
+    "32SC2": (np.int32, 2),
+    "32SC3": (np.int32, 3),
+    "32SC4": (np.int32, 4),
+    "32FC1": (np.float32, 1),
+    "32FC2": (np.float32, 2),
+    "32FC3": (np.float32, 3),
+    "32FC4": (np.float32, 4),
+    "64FC1": (np.float64, 1),
+    "64FC2": (np.float64, 2),
+    "64FC3": (np.float64, 3),
+    "64FC4": (np.float64, 4),
+}
+
+
+@converts_to_numpy(Image)
+def image_to_numpy(msg: Image) -> npt.NDArray[np.float64]:
+    if msg.encoding not in name_to_dtypes:
+        raise TypeError(f"Unrecognized encoding {msg.encoding}")
+
+    dtype_class, channels = name_to_dtypes[msg.encoding]
+    dtype = np.dtype(dtype_class)
+    dtype = dtype.newbyteorder(">" if msg.is_bigendian else "<")
+    shape = (msg.height, msg.width, channels)
+
+    data = np.frombuffer(msg.data, dtype=dtype).reshape(shape)
+    data.strides = (msg.step, dtype.itemsize * channels, dtype.itemsize)
+
+    if channels == 1:
+        data = data[..., 0]
+    return data
+
+
+@converts_from_numpy(Image)
+def numpy_to_image(arr: npt.NDArray[np.float64], encoding: str) -> Image:
+    if encoding not in name_to_dtypes:
+        raise TypeError(f"Unrecognized encoding {encoding}")
+
+    im = Image(encoding=encoding)
+
+    # extract width, height, and channels
+    dtype_class, exp_channels = name_to_dtypes[encoding]
+    if len(arr.shape) == 2:
+        im.height, im.width, channels = arr.shape + (1,)
+    elif len(arr.shape) == 3:
+        im.height, im.width, channels = arr.shape
+    else:
+        raise TypeError("Array must be two or three dimensional")
+
+    # check type and channels
+    if exp_channels != channels:
+        raise TypeError(
+            f"Array has 
{channels} channels, {encoding} requires {exp_channels}"
+        )
+    if dtype_class != arr.dtype.type:
+        raise TypeError(f"Array is {arr.dtype.type}, {encoding} requires {dtype_class}")
+
+    # make the array contiguous in memory, as mostly required by the format
+    contig = np.ascontiguousarray(arr)
+    im.data = contig.tobytes()
+    im.step = contig.strides[0]
+    im.is_bigendian = (
+        arr.dtype.byteorder == ">"
+        or arr.dtype.byteorder == "="
+        and sys.byteorder == "big"
+    )
+
+    return im
diff --git a/pkgs/node_helpers/node_helpers/ros2_numpy/laser_scan.py b/pkgs/node_helpers/node_helpers/ros2_numpy/laser_scan.py
new file mode 100644
index 0000000..8e5c598
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/ros2_numpy/laser_scan.py
@@ -0,0 +1,159 @@
+"""
+Methods to numpify LaserScan messages.
+"""
+
+import numpy as np
+import numpy.typing as npt
+from sensor_msgs.msg import LaserScan
+from std_msgs.msg import Header
+
+from .registry import converts_from_numpy, converts_to_numpy
+
+
+@converts_to_numpy(LaserScan)
+def laserscan_to_array(
+    scan: LaserScan,
+    remove_invalid_ranges: bool = False,
+    include_ranges_and_intensities: bool = False,
+) -> npt.NDArray[np.float32]:
+    """
+    Takes a sensor_msgs/msg/LaserScan msg and returns a structured array with
+    fields x, y and z that correspond to cartesian position data. Optionally,
+    ranges and intensities fields that correspond to the range and intensity
+    of a point are also included, if include_ranges_and_intensities is True.
+
+    :param scan: The input laser scan message to convert
+    :param remove_invalid_ranges: whether to remove invalid ranges from the input scan,
+        by default False
+    :param include_ranges_and_intensities: whether to also return the ranges &
+        intensities along with the cartesian position.
+
+    :returns: numpy array of k structured records, each containing either 3 or
+        5 fields. If include_ranges_and_intensities is False, each element of
+        the output array is a structured record with the 3 fields ['x', 'y',
+        'z'] of type float32. Else, it has the 5 fields ['x', 'y', 'z',
+        'ranges', 'intensities']. Since the output is a structured array,
+        all the x-coordinates of the points can be accessed as out_array['x'].
+        Similarly, the y and z coordinates can be accessed as out_array['y']
+        and out_array['z'] respectively. 
+    """
+    n_points = len(scan.ranges)
+    angles = np.linspace(
+        scan.angle_min,
+        scan.angle_max,
+        n_points,
+    )
+    ranges = np.array(scan.ranges, dtype="f4")
+    intensities = np.array(scan.intensities, dtype="f4")
+    if remove_invalid_ranges:
+        indices_invalid_range = (
+            np.isinf(ranges)
+            | np.isnan(ranges)
+            | (ranges < scan.range_min)
+            | (ranges > scan.range_max)
+        )
+        ranges = ranges[~indices_invalid_range]
+        angles = angles[~indices_invalid_range]
+        intensities = intensities[~indices_invalid_range]
+
+    x = np.array(ranges * np.cos(angles), dtype="f4")
+    y = np.array(ranges * np.sin(angles), dtype="f4")
+    z = np.zeros(ranges.shape[0], dtype="f4")
+    if include_ranges_and_intensities:
+        dtype = np.dtype(
+            [
+                ("x", "f4"),
+                ("y", "f4"),
+                ("z", "f4"),
+                ("ranges", "f4"),
+                ("intensities", "f4"),
+            ]
+        )
+        out_array = np.empty(len(x), dtype=dtype)
+        out_array["x"] = x
+        out_array["y"] = y
+        out_array["z"] = z
+        out_array["ranges"] = ranges
+        out_array["intensities"] = intensities
+        return out_array
+    else:
+        dtype = np.dtype([("x", "f4"), ("y", "f4"), ("z", "f4")])
+        out_array = np.empty(len(x), dtype=dtype)
+        out_array["x"] = x
+        out_array["y"] = y
+        out_array["z"] = z
+        return out_array
+
+
+@converts_from_numpy(LaserScan)
+def array_to_laserscan(
+    arr: npt.NDArray[np.float32],
+    header: Header,
+    scan_time: float = 0.0,
+    time_increment: float = 0.0,
+) -> LaserScan:
+    """
+    Takes a structured array (created from a LaserScan msg) and returns a
+    LaserScan message. Fields that cannot be determined from the numpy
+    array are provided as inputs. Since the LaserScan message relies on a
+    consistent angular increment, a structured array from which points have
+    been removed will produce an incorrect LaserScan message.
+
+    :param arr: Structured numpy array with fields x, y and z, as created from
+        a LaserScan message.
+    :param header: The std_msgs::msg::Header to be written to the output
+        LaserScan message.
+    :param scan_time: Optional time between scans [seconds] of the LaserScan
+    :param time_increment: Optional time between measurements [seconds]. If the
+        lidar is moving, this will be used to interpolate position.
+
+    :returns: The sensor_msgs::msg::LaserScan message. If the input array does
+        not contain an `intensities` field, the message's intensities list is
+        filled with zeros.
+    """
+    n_points = arr.shape[0]
+
+    if "intensities" in arr.dtype.names:
+        intensities = arr["intensities"]
+    else:
+        intensities = np.zeros(n_points).astype(float)
+
+    if "ranges" in arr.dtype.names:
+        ranges = arr["ranges"]
+    else:
+        ranges = np.sqrt(arr["x"] ** 2 + arr["y"] ** 2).astype(float)
+
+    angles = np.arctan2(arr["y"], arr["x"]).astype(float)
+
+    # Create a LaserScan message
+    scan_msg = LaserScan()
+
+    scan_msg.header = header
+
+    scan_msg.intensities = intensities.tolist()
+    scan_msg.ranges = ranges.tolist()
+
+    # Compute min and max of the ranges
+    scan_msg.range_min = np.min(ranges)
+    scan_msg.range_max = np.max(ranges)
+
+    # Compute min and max of the angles
+    scan_msg.angle_min = np.min(angles)
+    scan_msg.angle_max = np.max(angles)
+
+    # Use the time_increment and scan_time from input arguments
+    scan_msg.time_increment = time_increment
+    scan_msg.scan_time = scan_time
+
+    # Compute angle increment. Since the angle_max is not exclusive, omit
+    # the last point.
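+    # For example, a scan spanning angle_min=-pi/2 to angle_max=pi/2 with
+    # n_points=181 yields an increment of pi/180, i.e. one degree between rays.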
+    scan_msg.angle_increment = (scan_msg.angle_max - scan_msg.angle_min) / (
+        n_points - 1
+    )
+
+    return scan_msg
diff --git a/pkgs/node_helpers/node_helpers/ros2_numpy/occupancy_grid.py b/pkgs/node_helpers/node_helpers/ros2_numpy/occupancy_grid.py
new file mode 100644
index 0000000..e2f118b
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/ros2_numpy/occupancy_grid.py
@@ -0,0 +1,36 @@
+from array import array
+
+import numpy as np
+import numpy.typing as npt
+from nav_msgs.msg import MapMetaData, OccupancyGrid
+
+from .registry import converts_from_numpy, converts_to_numpy
+
+
+@converts_to_numpy(OccupancyGrid)
+def occupancygrid_to_numpy(msg: OccupancyGrid) -> npt.NDArray[np.float64]:
+    data = np.asarray(msg.data, dtype=np.int8).reshape(msg.info.height, msg.info.width)
+
+    return np.ma.array(data, mask=data == -1, fill_value=-1)
+
+
+@converts_from_numpy(OccupancyGrid)
+def numpy_to_occupancy_grid(
+    arr: npt.NDArray[np.float64], info: MapMetaData | None = None
+) -> OccupancyGrid:
+    if len(arr.shape) != 2:
+        raise TypeError("Array must be 2D")
+    if arr.dtype != np.int8:
+        raise TypeError("Array must be of int8s")
+
+    grid = OccupancyGrid()
+    if isinstance(arr, np.ma.MaskedArray):
+        # We assume that the masked values are already -1, for speed
+        arr = arr.data
+
+    grid.data = array("b", arr.ravel().astype(np.int8))
+    grid.info = info or MapMetaData()
+    grid.info.height = arr.shape[0]
+    grid.info.width = arr.shape[1]
+
+    return grid
diff --git a/pkgs/node_helpers/node_helpers/ros2_numpy/point_cloud2.py b/pkgs/node_helpers/node_helpers/ros2_numpy/point_cloud2.py
new file mode 100644
index 0000000..91a3f73
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/ros2_numpy/point_cloud2.py
@@ -0,0 +1,307 @@
+# Software License Agreement (BSD License)
+#
+# Copyright (c) 2008, Willow Garage, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above
+#    copyright notice, this list of conditions and the following
+#    disclaimer in the documentation and/or other materials provided
+#    with the distribution.
+#  * Neither the name of Willow Garage, Inc. nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# Author: Jon Binney
+
+"""
+Functions for working with PointCloud2. 
+""" + +__docformat__ = "restructuredtext en" + +import array +import sys +from collections.abc import Sequence +from typing import Any + +import numpy as np +import numpy.typing as npt +from builtin_interfaces.msg import Time +from sensor_msgs.msg import PointCloud2, PointField + +from .registry import converts_from_numpy, converts_to_numpy + +# prefix to the names of dummy fields we add to get byte alignment +# correct. this needs to not clash with any actual field names +DUMMY_FIELD_PREFIX = "__" + +# mappings between PointField types and numpy types +type_mappings = [ + (PointField.INT8, np.dtype("int8")), + (PointField.UINT8, np.dtype("uint8")), + (PointField.INT16, np.dtype("int16")), + (PointField.UINT16, np.dtype("uint16")), + (PointField.INT32, np.dtype("int32")), + (PointField.UINT32, np.dtype("uint32")), + (PointField.FLOAT32, np.dtype("float32")), + (PointField.FLOAT64, np.dtype("float64")), +] +pftype_to_nptype = dict(type_mappings) +nptype_to_pftype = {nptype: pftype for pftype, nptype in type_mappings} + + +@converts_to_numpy(PointField, plural=True) +def fields_to_dtype(fields: Sequence[PointField], point_step: int) -> Any: + """Convert a list of PointFields to a numpy record datatype.""" + offset = 0 + np_dtype_list = [] + for f in fields: + while offset < f.offset: + # might be extra padding between fields + np_dtype_list.append(("%s%d" % (DUMMY_FIELD_PREFIX, offset), np.uint8)) + offset += 1 + + dtype = pftype_to_nptype[f.datatype] + if f.count != 1: + dtype = np.dtype((dtype, f.count)) + + np_dtype_list.append((f.name, dtype)) + offset += pftype_to_nptype[f.datatype].itemsize * f.count + + # might be extra padding between points + while offset < point_step: + np_dtype_list.append(("%s%d" % (DUMMY_FIELD_PREFIX, offset), np.uint8)) + offset += 1 + + return np_dtype_list + + +@converts_from_numpy(PointField, plural=True) +def dtype_to_fields(dtype: Any) -> Sequence[PointField]: + """Convert a numpy record datatype into a list of PointFields.""" + fields = [] + for field_name in dtype.names: + np_field_type, field_offset = dtype.fields[field_name] + pf = PointField() + pf.name = field_name + if np_field_type.subdtype: + item_dtype, shape = np_field_type.subdtype + pf.count = int(np.prod(shape)) + np_field_type = item_dtype + else: + pf.count = 1 + + pf.datatype = nptype_to_pftype[np_field_type] + pf.offset = field_offset + fields.append(pf) + return fields + + +@converts_to_numpy(PointCloud2) +def pointcloud2_to_array( + cloud_msg: PointCloud2, squeeze: bool = True +) -> npt.NDArray[np.float64]: + """Converts a rospy PointCloud2 message to a numpy recordarray + + Reshapes the returned array to have shape (height, width), even if the + height is 1. + + The reason for using np.frombuffer rather than struct.unpack is + speed... especially for large point clouds, this will be faster. 
+
+    :param cloud_msg: The message to convert
+    :param squeeze: If True, and the height is 1, the array is reshaped to width
+    :return: The numpy array
+    """
+    # construct a numpy record type equivalent to the point type of this cloud
+    dtype_list = fields_to_dtype(cloud_msg.fields, cloud_msg.point_step)
+
+    # parse the cloud into an array
+    cloud_arr = np.frombuffer(cloud_msg.data, dtype_list)
+
+    # remove the dummy fields that were added
+    cloud_arr = cloud_arr[
+        [
+            fname
+            for fname, _type in dtype_list
+            if fname[: len(DUMMY_FIELD_PREFIX)] != DUMMY_FIELD_PREFIX
+        ]
+    ]
+
+    if squeeze and cloud_msg.height == 1:
+        return np.reshape(cloud_arr, (cloud_msg.width,))
+    else:
+        return np.reshape(cloud_arr, (cloud_msg.height, cloud_msg.width))
+
+
+@converts_from_numpy(PointCloud2)
+def array_to_pointcloud2(
+    cloud_arr: npt.NDArray[np.float64],
+    stamp: Time | None = None,
+    frame_id: str | None = None,
+) -> PointCloud2:
+    """Converts a numpy record array to a sensor_msgs.msg.PointCloud2."""
+    # make it 2d (even if height will be 1)
+    cloud_arr = np.atleast_2d(cloud_arr)
+
+    cloud_msg = PointCloud2()
+
+    if stamp is not None:
+        cloud_msg.header.stamp = stamp
+    if frame_id is not None:
+        cloud_msg.header.frame_id = frame_id
+    cloud_msg.height = cloud_arr.shape[0]
+    cloud_msg.width = cloud_arr.shape[1]
+    cloud_msg.fields = dtype_to_fields(cloud_arr.dtype)
+    cloud_msg.is_bigendian = sys.byteorder != "little"
+    cloud_msg.point_step = cloud_arr.dtype.itemsize
+    cloud_msg.row_step = cloud_msg.point_step * cloud_arr.shape[1]
+    cloud_msg.is_dense = all(
+        np.isfinite(cloud_arr[fname]).all() for fname in cloud_arr.dtype.names
+    )
+
+    # The PointCloud2.data setter will create an array.array object for you if you
+    # don't provide it one directly. This causes very slow performance because it
+    # iterates over each byte in python.
+    # Here we create an array.array object using a memoryview, limiting copying and
+    # increasing performance.
+    memory_view = memoryview(cloud_arr)
+    # Casting raises a TypeError if the array has no elements
+    array_bytes = memory_view.cast("B") if memory_view.nbytes > 0 else b""
+    as_array = array.array("B")
+    as_array.frombytes(array_bytes)
+    cloud_msg.data = as_array
+    return cloud_msg
+
+
+def merge_rgb_fields(cloud_arr: Any) -> npt.NDArray[np.float64]:
+    """Takes an array with named np.uint8 fields 'r', 'g', and 'b', and returns
+    an array in which they have been merged into a single np.float32 'rgb'
+    field. The first byte of this field is the 'r' uint8, the second is the
+    'g' uint8, and the third is the 'b' uint8.
+
+    This is the way that pcl likes to handle RGB colors for some reason.
+
+    :param cloud_arr: The array to convert
+    :return: An array with merged color values
+    """
+    r = np.asarray(cloud_arr["r"], dtype=np.uint32)
+    g = np.asarray(cloud_arr["g"], dtype=np.uint32)
+    b = np.asarray(cloud_arr["b"], dtype=np.uint32)
+    rgb_arr = np.array((r << 16) | (g << 8) | (b << 0), dtype=np.uint32)
+
+    # Reinterpret the uint32 array as float32 by changing the dtype in place;
+    # no numeric conversion of the underlying bytes takes place.
+    rgb_arr.dtype = np.float32
+
+    # create a new array, without r, g, and b, but with rgb float32 field
+    new_dtype = []
+    for field_name in cloud_arr.dtype.names:
+        field_type, field_offset = cloud_arr.dtype.fields[field_name]
+        if field_name not in ("r", "g", "b"):
+            new_dtype.append((field_name, field_type))
+    new_dtype.append(("rgb", np.float32))
+    new_cloud_arr = np.zeros(cloud_arr.shape, new_dtype)
+
+    # fill in the new array
+    for field_name in new_cloud_arr.dtype.names:
+        if field_name == "rgb":
+            new_cloud_arr[field_name] = rgb_arr
+        else:
+            new_cloud_arr[field_name] = cloud_arr[field_name]
+
+    return new_cloud_arr
+
+
+def split_rgb_field(cloud_arr: Any) -> npt.NDArray[np.float64]:
+    """Takes an array with a named 'rgb' float32 field, and returns an array in
+    which this has been split into 3 uint8 fields: 'r', 'g', and 'b'.
+
+    (pcl stores rgb in packed 32 bit floats)
+
+    :param cloud_arr: The array to convert
+    :return: An array with split color values
+    """
+    rgb_arr = cloud_arr["rgb"].copy()
+    rgb_arr.dtype = np.uint32
+    r = np.asarray((rgb_arr >> 16) & 255, dtype=np.uint8)
+    g = np.asarray((rgb_arr >> 8) & 255, dtype=np.uint8)
+    b = np.asarray(rgb_arr & 255, dtype=np.uint8)
+
+    # create a new array, without rgb, but with r, g, and b fields
+    new_dtype = []
+    for field_name in cloud_arr.dtype.names:
+        field_type, field_offset = cloud_arr.dtype.fields[field_name]
+        if field_name != "rgb":
+            new_dtype.append((field_name, field_type))
+    new_dtype.append(("r", np.uint8))
+    new_dtype.append(("g", np.uint8))
+    new_dtype.append(("b", np.uint8))
+    new_cloud_arr = np.zeros(cloud_arr.shape, new_dtype)
+
+    # fill in the new array
+    for field_name in new_cloud_arr.dtype.names:
+        if field_name == "r":
+            new_cloud_arr[field_name] = r
+        elif field_name == "g":
+            new_cloud_arr[field_name] = g
+        elif field_name == "b":
+            new_cloud_arr[field_name] = b
+        else:
+            new_cloud_arr[field_name] = cloud_arr[field_name]
+    return new_cloud_arr
+
+
+def get_xyz_points(
+    cloud_array: Any, remove_nans: bool = True, dtype: type = float
+) -> npt.NDArray[np.float64]:
+    """Pulls out x, y, and z columns from the cloud recordarray, and returns
+    an Nx3 matrix. 
+
+    :param cloud_array: The array to convert
+    :param remove_nans: If True, removes points with NaN or infinite coordinates
+    :param dtype: The dtype of the returned array
+    :return: An Nx3 array of points
+    """
+    # remove points with non-finite coordinates
+    if remove_nans:
+        mask = (
+            np.isfinite(cloud_array["x"])
+            & np.isfinite(cloud_array["y"])
+            & np.isfinite(cloud_array["z"])
+        )
+        cloud_array = cloud_array[mask]
+
+    # pull out x, y, and z values
+    points = np.zeros(cloud_array.shape + (3,), dtype=dtype)
+    points[..., 0] = cloud_array["x"]
+    points[..., 1] = cloud_array["y"]
+    points[..., 2] = cloud_array["z"]
+
+    return points
+
+
+def pointcloud2_to_xyz_array(
+    cloud_msg: PointCloud2, remove_nans: bool = True
+) -> npt.NDArray[np.float64]:
+    return get_xyz_points(pointcloud2_to_array(cloud_msg), remove_nans=remove_nans)
diff --git a/pkgs/node_helpers/node_helpers/ros2_numpy/registry.py b/pkgs/node_helpers/node_helpers/ros2_numpy/registry.py
new file mode 100644
index 0000000..25d20ac
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/ros2_numpy/registry.py
@@ -0,0 +1,54 @@
+from collections.abc import Sequence
+from typing import Any
+
+_to_numpy = {}
+_from_numpy = {}
+
+
+def converts_to_numpy(msgtype: Any, plural: bool = False) -> Any:
+    def decorator(f: Any) -> Any:
+        _to_numpy[msgtype, plural] = f
+        return f
+
+    return decorator
+
+
+def converts_from_numpy(msgtype: Any, plural: bool = False) -> Any:
+    def decorator(f: Any) -> Any:
+        _from_numpy[msgtype, plural] = f
+        return f
+
+    return decorator
+
+
+def numpify(msg: Any, *args: Any, **kwargs: Any) -> Any:
+    if msg is None:
+        return None
+
+    conv = _to_numpy.get((msg.__class__, False))
+    if not conv and isinstance(msg, Sequence):
+        if not msg:
+            raise ValueError("Cannot determine the type of an empty sequence")
+        conv = _to_numpy.get((msg[0].__class__, True))
+
+    if not conv:
+        raise ValueError(
+            "Unable to convert message {} - only supports {}".format(
+                msg.__class__.__name__,
+                ", ".join(cls.__name__ + ("[]" if pl else "") for cls, pl in _to_numpy),
+            )
+        )
+
+    return conv(msg, *args, **kwargs)
+
+
+def msgify(msg_type: Any, numpy_obj: Any, *args: Any, **kwargs: Any) -> Any:
+    conv = _from_numpy.get((msg_type, kwargs.pop("plural", False)))
+    if not conv:
+        raise ValueError(
+            "Unable to build message {} - only supports {}".format(
+                msg_type.__name__,
+                ", ".join(
+                    cls.__name__ + ("[]" if pl else "") for cls, pl in _from_numpy
+                ),
+            )
+        )
+    return conv(numpy_obj, *args, **kwargs)
diff --git a/pkgs/node_helpers/node_helpers/sensors/README.md b/pkgs/node_helpers/node_helpers/sensors/README.md
new file mode 100644
index 0000000..355be7a
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/sensors/README.md
@@ -0,0 +1,12 @@
+# node_helpers.sensors
+
+The `node_helpers.sensors` module provides a standardized framework for handling sensor data in ROS2. It includes tools for publishing, buffering, and visualizing sensor messages, making it easy to implement and reuse components across different sensor types.
+
+Key features:
+- **Publishers**: `BaseSensorPublisher` facilitates structured sensor message publishing with built-in support for RViz visualization.
+- **Buffers**: `BaseSensorBuffer` holds and updates the latest sensor readings for on-demand access.
+- **Predefined Sensors**: Includes reusable components for binary sensors and rangefinders.
+
+This framework simplifies creating and visualizing new sensors while promoting best practices for consistent sensor message handling. 
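+
+As a rough sketch (the frame and topic names here are illustrative), a rangefinder
+driver might use the `RangefinderPublisher` from this module like so:
+
+```python
+from rclpy.qos import qos_profile_sensor_data
+
+from node_helpers.sensors import RangefinderPublisher
+
+publisher = RangefinderPublisher(
+    node=node,  # an already-initialized rclpy Node
+    parameters=RangefinderPublisher.Parameters(
+        frame_id="rangefinder_link",  # hypothetical TF frame
+        sensor_topic="/sensors/rangefinder",  # hypothetical topic name
+    ),
+    qos=qos_profile_sensor_data,
+)
+publisher.publish_range(0.42)  # publishes a RangefinderReading of 0.42 meters
+```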
+
+For detailed documentation, see [docs/](../../../../docs/sensors.rst)
diff --git a/pkgs/node_helpers/node_helpers/sensors/__init__.py b/pkgs/node_helpers/node_helpers/sensors/__init__.py
new file mode 100644
index 0000000..5d9146a
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/sensors/__init__.py
@@ -0,0 +1,10 @@
+from .base_buffer import BaseSensorBuffer
+from .base_publisher import BaseSensorPublisher
+
+from .binary_signal import (  # isort: skip
+    BinarySensorBuffer,  # isort: skip
+    BinarySensorFromRangeFinder,  # isort: skip
+    BinarySensor,  # isort: skip
+)  # isort: skip
+from .rangefinder import RangefinderBuffer, RangefinderPublisher
+from .typing import SensorProtocol
diff --git a/pkgs/node_helpers/node_helpers/sensors/base_buffer.py b/pkgs/node_helpers/node_helpers/sensors/base_buffer.py
new file mode 100644
index 0000000..2c666ce
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/sensors/base_buffer.py
@@ -0,0 +1,114 @@
+import logging
+from abc import ABC
+from typing import Generic, TypeVar, cast
+
+from builtin_interfaces.msg import Time
+from rclpy.callback_groups import CallbackGroup
+from rclpy.node import Node
+from rclpy.qos import QoSProfile, qos_profile_sensor_data
+from rclpy.time import Time as RclpyTime
+
+from node_helpers.pubsub import Topic
+from node_helpers.tf import timestamps
+from node_helpers.timing import Timeout
+
+from .typing import SensorProtocol
+
+SENSOR_MSG = TypeVar("SENSOR_MSG", bound=SensorProtocol)
+
+
+class BaseSensorBuffer(ABC, Generic[SENSOR_MSG]):
+    """A sensor buffer is simple: it subscribes to a sensor topic and holds the latest
+    value for the user.
+    """
+
+    def __init__(
+        self,
+        node: Node,
+        msg_type: type[SENSOR_MSG],
+        sensor_topic: str,
+        sensor_qos: QoSProfile = qos_profile_sensor_data,
+        callback_group: CallbackGroup | None = None,
+    ):
+        super().__init__()
+        # Topics
+        self.on_value_change = Topic[SENSOR_MSG]()
+        """Called when a reading has changed from a previous value"""
+        self.on_receive = Topic[SENSOR_MSG]()
+        """Called whenever a new reading is received"""
+
+        self._latest_reading: SENSOR_MSG | None = None
+        node.create_subscription(
+            msg_type=msg_type,
+            topic=sensor_topic,
+            callback=self._on_receive,
+            qos_profile=sensor_qos,
+            callback_group=callback_group,
+        )
+
+    def _on_receive(self, msg: SENSOR_MSG) -> None:
+        """Called whenever a sensor reading is received"""
+        changed = (
+            self._latest_reading is None or msg.value != self._latest_reading.value
+        )
+
+        # Reject messages that are older than the latest received message
+        if self._latest_reading is not None and timestamps.is_older(
+            msg.header.stamp,
+            self._latest_reading.header.stamp,
+        ):
+            logging.error(
+                f"Refusing to receive out-of-order message for {msg.header.frame_id}!"
+            )
+            return
+
+        self._latest_reading = msg
+
+        # Alert users of the callback API
+        if changed:
+            self.on_value_change.publish(self._latest_reading)
+        self.on_receive.publish(self._latest_reading)
+
+    def get(
+        self, after: Time | None = None, timeout: float | None = None
+    ) -> SENSOR_MSG:
+        """Get a sensor message from the buffer
+
+        :param after: Get a sensor reading taken at or after the given time. If None,
+            the latest reading in the buffer will be returned.
+        :param timeout: If None, block until the 'after' criteria is met. If 0, check
+            instantaneously if the 'after' criteria is met, and if not fail immediately. 
+            If greater than 0, wait up to that many seconds for the 'after' criteria
+            to be met.
+        :raises RuntimeError: If the wait loop exits without returning a value, which
+            should be impossible
+        :return: The sensor message
+        """
+
+        # First, check if the latest value already in the buffer matches the requirement
+        after = RclpyTime.from_msg(after) if after is not None else None
+        if self._latest_reading is not None and (
+            after is None
+            or RclpyTime.from_msg(self._latest_reading.header.stamp) > after
+        ):
+            return self._latest_reading
+
+        on_receive = self.on_receive.subscribe_as_event()
+        check_timeout = (
+            Timeout(timeout, raise_error=True) if timeout is not None else True
+        )
+        while check_timeout:
+            if not on_receive.wait(timeout=timeout or 10.0):
+                logging.warning(
+                    f"{self.__class__.__name__}.get() is taking longer than"
+                    f" expected..."
+                )
+                continue
+
+            new_msg = cast(SENSOR_MSG, self._latest_reading)
+            if after is not None and RclpyTime.from_msg(new_msg.header.stamp) < after:
+                # This message is still older than the requested 'after' value
+                continue
+
+            return new_msg
+
+        raise RuntimeError("This function should have returned a value!")
diff --git a/pkgs/node_helpers/node_helpers/sensors/base_publisher.py b/pkgs/node_helpers/node_helpers/sensors/base_publisher.py
new file mode 100644
index 0000000..6c67b0a
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/sensors/base_publisher.py
@@ -0,0 +1,177 @@
+import logging
+from abc import ABC, abstractmethod
+from threading import RLock
+from typing import Any, Generic, TypeVar
+
+from builtin_interfaces.msg import Time
+from pydantic import BaseModel
+from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
+from rclpy.node import Node
+from rclpy.qos import QoSProfile, qos_profile_sensor_data
+from std_msgs.msg import Header
+from visualization_msgs.msg import Marker, MarkerArray
+
+from node_helpers import markers
+from node_helpers.tf import timestamps
+from node_helpers.timing import ttl_cached
+
+from .typing import SensorProtocol
+
+SENSOR_MSG = TypeVar("SENSOR_MSG", bound=SensorProtocol)
+SENSOR_VALUE = TypeVar("SENSOR_VALUE", bound=Any)
+PARAMETERS = TypeVar("PARAMETERS", bound="BaseSensorPublisher.Parameters")
+
+DEFAULT_SENSORS_VIS_TOPIC = "/debug/sensors"
+WRAPPED_FN = TypeVar("WRAPPED_FN")
+
+
+class BaseSensorPublisher(ABC, Generic[SENSOR_MSG, SENSOR_VALUE, PARAMETERS]):
+    """The basic sensor publisher structure.
+
+    It standardizes sensor publishing by assigning a default QoS, unifying the
+    visualization callbacks, and enforcing a consistent sensor message structure.
+    """
+
+    class Parameters(BaseModel):
+        frame_id: str
+        """The frame ID wherein the sensor is centered upon the origin."""
+
+        sensor_topic: str
+        """The topic to publish sensor values to."""
+
+        vis_topic: str = DEFAULT_SENSORS_VIS_TOPIC
+        """The topic to publish visualization markers to."""
+
+        sensor_publishing_max_hz: float = 0.0
+        """Optionally throttle the sensor publishing rate. 0.0 means no throttle."""
+
+        vis_publishing_max_hz: float = 0.0
+        """Optionally throttle the visualization publishing rate. 0.0 means no throttle. 
+        """
+
+    def __init__(
+        self,
+        node: Node,
+        msg_type: type[SENSOR_MSG],
+        parameters: PARAMETERS,
+        sensor_qos: QoSProfile = qos_profile_sensor_data,
+        vis_qos: QoSProfile = qos_profile_sensor_data,
+    ):
+        """
+        :param node: The node to use for publishing
+        :param msg_type: The sensor message type
+        :param parameters: The parameters for the sensor publisher
+        :param sensor_qos: The QoS for sensor publishing
+        :param vis_qos: The QoS for visualization publishing
+        """
+        self._params = parameters
+        self._sensor_topic = self._params.sensor_topic
+        self._msg_type = msg_type
+        self._clock = node.get_clock()
+        self._sensor_publisher = node.create_publisher(
+            msg_type,
+            self._sensor_topic,
+            sensor_qos,
+            callback_group=MutuallyExclusiveCallbackGroup(),
+        )
+        self._visualization_publisher = node.create_publisher(
+            MarkerArray,
+            self._params.vis_topic,
+            vis_qos,
+            callback_group=MutuallyExclusiveCallbackGroup(),
+        )
+        self._last_published_time: Time = Time()
+        """For ensuring there's no out-of-order publishing"""
+
+        self.frame_id = self._params.frame_id
+
+        self._publish_sensor_lock = RLock()
+        self.publish_sensor = self._maybe_throttle_function(  # type: ignore
+            self._params.sensor_publishing_max_hz, self.publish_sensor
+        )
+        self._publish_rviz_markers = self._maybe_throttle_function(  # type: ignore
+            self._params.vis_publishing_max_hz, self._publish_rviz_markers
+        )
+
+    def _maybe_throttle_function(self, max_hz: float, fn: WRAPPED_FN) -> WRAPPED_FN:
+        """Optionally wrap a function with a rate limiter"""
+        if max_hz <= 0:
+            return fn
+
+        logging.info(
+            f"Throttling sensor '{self.frame_id}' {fn.__name__} to {max_hz}hz"  # type: ignore
+        )
+        return ttl_cached(1 / max_hz)(fn)  # type: ignore
+
+    @property
+    def marker_namespace(self) -> str:
+        return f"{self.frame_id}.{self._sensor_topic}"
+
+    def publish_value(
+        self, sensor_value: SENSOR_VALUE, stamp: Time | None = None
+    ) -> None:
+        """Helper for publishing a stamped sensor message
+
+        :param sensor_value: The Sensor.value portion of the message
+        :param stamp: If None, clock.now() will be used for the header.
+        """
+
+        # Since old timestamps can be rejected, we hold the lock so that we can create
+        # a new stamp (when needed) and publish it before another thread possibly can.
+        with self._publish_sensor_lock:
+            msg = self._msg_type(
+                header=Header(
+                    frame_id=self.frame_id, stamp=stamp or self._clock.now().to_msg()
+                ),
+                value=sensor_value,
+            )
+            self.publish_sensor(msg)
+
+    def publish_sensor(self, sensor_msg: SENSOR_MSG) -> None:
+        # Ensure the message is filled out
+        if sensor_msg.header.stamp == Time():
+            raise ValueError("A sensor message cannot have an empty timestamp!")
+        if sensor_msg.header.frame_id != self.frame_id:
+            raise ValueError(
+                "Until a need arises, it's not allowed to publish a sensor message with"
+                " a frame_id different from the publisher's assigned frame_id!"
+            )
+
+        with self._publish_sensor_lock:
+            # Ensure that we're not publishing out of order
+            if timestamps.is_older(sensor_msg.header.stamp, self._last_published_time):
+                logging.warning(
+                    f"The publisher was told to publish an out-of-order message! This "
+                    f"is only okay during initialization, when initial values are "
+                    f"being set while upstream messages might be arriving. Sensor: "
+                    f"'{self.frame_id}'. 
" + f"Message: {sensor_msg}, Latest Time: {self._last_published_time}" + ) + + self._sensor_publisher.publish(sensor_msg) + self._last_published_time = sensor_msg.header.stamp + self._publish_rviz_markers(sensor_msg) + + def _publish_rviz_markers(self, sensor_msg: SENSOR_MSG) -> None: + """Publish a list of rviz markers to the visualization topic""" + # Create visualizations for this sensor + rviz_markers = self.to_rviz_msg(sensor_msg) + # Clean up the markers so that each marker has the same header as the sensor msg + # This is pretty opinionated, and could be a target for change in the future. + for marker in rviz_markers: + if marker.header == Header(): + marker.header = sensor_msg.header + elif not (marker.header.frame_id != "" and marker.header.stamp != Time()): + raise ValueError( + "If a marker header already has has either a frame_id or a " + "time stamp, then it must be fully defined!" + ) + namespaced_marker_array = markers.ascending_id_marker_array( + rviz_markers, marker_namespace=self.marker_namespace + ) + + self._visualization_publisher.publish(namespaced_marker_array) + + @abstractmethod + def to_rviz_msg(self, msg: SENSOR_MSG) -> list[Marker]: + """This method should be able to take a sensor msg and convert it to markers""" + raise NotImplementedError diff --git a/pkgs/node_helpers/node_helpers/sensors/binary_signal.py b/pkgs/node_helpers/node_helpers/sensors/binary_signal.py new file mode 100644 index 0000000..3c4b8bd --- /dev/null +++ b/pkgs/node_helpers/node_helpers/sensors/binary_signal.py @@ -0,0 +1,125 @@ +from typing import TypeVar + +from builtin_interfaces.msg import Duration +from geometry_msgs.msg import Vector3 +from node_helpers_msgs.msg import BinaryReading, RangefinderReading +from rclpy.callback_groups import CallbackGroup +from rclpy.node import Node +from rclpy.qos import qos_profile_action_status_default +from std_msgs.msg import ColorRGBA +from visualization_msgs.msg import Marker + +from ..pubsub import Topic +from .base_buffer import BaseSensorBuffer +from .base_publisher import BaseSensorPublisher +from .rangefinder import RangefinderBuffer + +PARAMETERS = TypeVar("PARAMETERS", bound=BaseSensorPublisher.Parameters) + + +def _to_rviz_msg(msg: BinaryReading) -> list[Marker]: + """Create a cylinder that turns green when the sensor is triggered""" + color = ( + ColorRGBA(r=0.2, g=1.0, b=0.2, a=1.0) + if msg.value + else ColorRGBA(r=0.4, g=0.2, b=0.2, a=1.0) + ) + return [ + Marker( + type=Marker.CYLINDER, + scale=Vector3(x=0.05, y=0.05, z=0.1), + color=color, + lifetime=Duration(sec=60 * 60), + frame_locked=True, + ) + ] + + +class BinarySensor( + BaseSensorPublisher[ + BinaryReading, bool, "BinarySensorFromFieldPublisher.Parameters" + ] +): + def __init__(self, node: Node, parameters: BaseSensorPublisher.Parameters): + super().__init__( + node=node, + msg_type=BinaryReading, + parameters=parameters, + sensor_qos=qos_profile_action_status_default, + ) + + def to_rviz_msg(self, msg: BinaryReading) -> list[Marker]: + return _to_rviz_msg(msg) + + +class BinarySensorFromRangeFinder( + BaseSensorPublisher[BinaryReading, bool, "BinarySensorFromRangeFinder.Parameters"] +): + class Parameters(BaseSensorPublisher.Parameters): + threshold: float + """A threshold in meters. If the rangefinder reading is below this value, the + sensor will be triggered. Use 'invert' to flip this behavior. 
+        """
+
+        inverted: bool = False
+        """If True, the sensor will be triggered when the rangefinder reading is above
+        the threshold."""
+
+        vis_publishing_max_hz: float = 10.0
+        """This sets a default throttle for visualization: rangefinders operate as a
+        stream of data (as opposed to binary sensors), so it makes sense not to waste
+        rviz resources by publishing at a high rate."""
+
+    def __init__(
+        self,
+        node: Node,
+        rangefinder: RangefinderBuffer,
+        parameters: Parameters,
+    ):
+        super().__init__(
+            node=node,
+            msg_type=BinaryReading,
+            parameters=parameters,
+            sensor_qos=qos_profile_action_status_default,
+        )
+        rangefinder.on_value_change.subscribe(self._on_rangefinder_change)
+
+    def _on_rangefinder_change(self, reading: RangefinderReading) -> None:
+        # Check if the sensor should be triggered
+        value = reading.value.z < self._params.threshold
+        if self._params.inverted:
+            value = not value
+
+        self.publish_value(value, stamp=reading.header.stamp)
+
+    def to_rviz_msg(self, msg: BinaryReading) -> list[Marker]:
+        return _to_rviz_msg(msg)
+
+
+class BinarySensorBuffer(BaseSensorBuffer[BinaryReading]):
+    def __init__(
+        self, node: Node, sensor_topic: str, callback_group: CallbackGroup | None = None
+    ):
+        super().__init__(
+            node=node,
+            msg_type=BinaryReading,
+            sensor_topic=sensor_topic,
+            sensor_qos=qos_profile_action_status_default,
+            callback_group=callback_group,
+        )
+
+        self.on_rising_edge = Topic[BinaryReading]()
+        """Called when the sensor transitions from False to True"""
+
+        self.on_falling_edge = Topic[BinaryReading]()
+        """Called when the sensor transitions from True to False"""
+
+    def _on_receive(self, msg: BinaryReading) -> None:
+        """Called whenever a sensor reading is received"""
+        if self._latest_reading is not None:
+            if not self._latest_reading.value and msg.value:
+                self.on_rising_edge.publish(msg)
+            elif self._latest_reading.value and not msg.value:
+                self.on_falling_edge.publish(msg)
+
+        return super()._on_receive(msg)
diff --git a/pkgs/node_helpers/node_helpers/sensors/rangefinder.py b/pkgs/node_helpers/node_helpers/sensors/rangefinder.py
new file mode 100644
index 0000000..236b99b
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/sensors/rangefinder.py
@@ -0,0 +1,87 @@
+from geometry_msgs.msg import Point, Pose, Vector3
+from node_helpers_msgs.msg import RangefinderReading
+from rclpy.callback_groups import CallbackGroup
+from rclpy.node import Node
+from rclpy.qos import QoSProfile, qos_profile_services_default
+from std_msgs.msg import ColorRGBA
+from visualization_msgs.msg import Marker
+
+from node_helpers import markers
+
+from .base_buffer import BaseSensorBuffer
+from .base_publisher import BaseSensorPublisher
+
+
+class RangefinderPublisher(
+    BaseSensorPublisher[RangefinderReading, Vector3, "RangefinderPublisher.Parameters"]
+):
+    """Publishes rangefinder readings, along with RViz markers that visualize them.
+
+    Unlike a binary sensor, a rangefinder produces a continuous stream of readings, so
+    sensor and visualization publishing are throttled by default via the Parameters
+    below. 
+    """
+
+    class Parameters(BaseSensorPublisher.Parameters):
+        vis_publishing_max_hz: float = 10.0
+        """The Rangefinder sets a default throttle for visualization: most
+        rangefinders operate as a stream of data (as opposed to binary sensors), so it
+        makes sense not to waste rviz resources by publishing at a high rate."""
+
+        sensor_publishing_max_hz: float = 20.0
+        """The Rangefinder publishes on every value received instead of on value change,
+        so we throttle the sensor publishing rate by default
+        to avoid gumming up the ROS graph too much."""
+
+    def __init__(self, node: Node, parameters: Parameters, qos: QoSProfile):
+        super().__init__(
+            node=node,
+            msg_type=RangefinderReading,
+            parameters=parameters,
+            sensor_qos=qos,
+        )
+
+    def publish_range(self, value: float) -> None:
+        """Publish a new range value, in meters."""
+        self.publish_value(Vector3(z=value))
+
+    def to_rviz_msg(self, msg: RangefinderReading) -> list[Marker]:
+        """Create an arrow that reflects the rangefinder reading."""
+        color = (
+            ColorRGBA(r=0.2, g=1.0, b=0.2, a=1.0)
+            if msg.value
+            else ColorRGBA(r=0.4, g=0.2, b=0.2, a=1.0)
+        )
+        return [
+            # Mark frame_locked=True to fix flickering effect in rviz.
+            # This essentially lets rviz render the marker even if it hasn't yet
+            # received up-to-date TF information.
+            markers.create_point_to_point_arrow_marker(
+                shaft_diameter=0.01,
+                head_diameter=0.02,
+                head_length=0.03,
+                base_point=(0.0, 0.0, 0.0),
+                head_point=msg.value,
+                color=color,
+                frame_locked=True,
+            ),
+            markers.create_floating_text(
+                text=f"{msg.value.z:.3f}",
+                text_height=0.05,
+                color=color,
+                # Offset the text from the base of the sensor frame
+                pose=Pose(position=Point(x=0.05, y=0.05, z=0.05)),
+                frame_locked=True,
+            ),
+        ]
+
+
+class RangefinderBuffer(BaseSensorBuffer[RangefinderReading]):
+    def __init__(
+        self, node: Node, sensor_topic: str, callback_group: CallbackGroup | None = None
+    ):
+        super().__init__(
+            node=node,
+            msg_type=RangefinderReading,
+            sensor_topic=sensor_topic,
+            sensor_qos=qos_profile_services_default,
+            callback_group=callback_group,
+        )
diff --git a/pkgs/node_helpers/node_helpers/sensors/typing.py b/pkgs/node_helpers/node_helpers/sensors/typing.py
new file mode 100644
index 0000000..233f85a
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/sensors/typing.py
@@ -0,0 +1,16 @@
+from typing import Any, Protocol
+
+from std_msgs.msg import Header
+
+
+class SensorProtocol(Protocol):
+    """A sensor message will always have a header, and some value.
+
+    The header.frame_id will be a TF where the sensor is centered upon the origin.
+    The header.stamp is the timestamp of when that sensor reading was taken.
+    """
+
+    header: Header
+    value: Any
+
+    def __init__(self, header: Header, value: Any): ... 
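On the consumption side, a minimal sketch (assuming an rclpy `Node` that is already
spinning; the topic name and the `stamp` variable are placeholders):

```python
from node_helpers.sensors import RangefinderBuffer

buffer = RangefinderBuffer(node=node, sensor_topic="/sensors/rangefinder")

# React to every change in the reading as it arrives...
buffer.on_value_change.subscribe(lambda msg: print(msg.value.z))

# ...or block for up to 5 seconds for a reading taken after `stamp`
reading = buffer.get(after=stamp, timeout=5.0)
```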
diff --git a/pkgs/node_helpers/node_helpers/spinning/__init__.py b/pkgs/node_helpers/node_helpers/spinning/__init__.py
new file mode 100644
index 0000000..1f48011
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/spinning/__init__.py
@@ -0,0 +1,2 @@
+from .executor import MultiThreadedStackTracedExecutor
+from .initialization import DEFAULT_MULTITHREADED_THREAD_COUNT, create_spin_function
diff --git a/pkgs/node_helpers/node_helpers/spinning/executor.py b/pkgs/node_helpers/node_helpers/spinning/executor.py
new file mode 100644
index 0000000..c758b6f
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/spinning/executor.py
@@ -0,0 +1,27 @@
+from rclpy.executors import MultiThreadedExecutor
+
+
+class MultiThreadedStackTracedExecutor(MultiThreadedExecutor):
+    """This modifies the MultiThreadedExecutor to ensure that the 'shutdown()' call
+    blocks until all underlying threads have joined.
+    """
+
+    def shutdown(
+        self, timeout_sec: float | None = None, wait_for_threads: bool = True
+    ) -> bool:
+        """
+        Stop executing callbacks and wait for their completion.
+
+        :param timeout_sec: Seconds to wait. Block forever if ``None`` or negative.
+            Don't wait if 0.
+        :param wait_for_threads: If True, this function will exit only once all
+            executor threads have joined.
+        :return: ``True`` if all outstanding callbacks finished executing, or ``False``
+            if the timeout expires before all outstanding work is done.
+        """
+        success: bool = super().shutdown(timeout_sec)
+        self._executor.shutdown(wait=wait_for_threads)
+        return success
+
+
+__all__ = ["MultiThreadedStackTracedExecutor"]
diff --git a/pkgs/node_helpers/node_helpers/spinning/initialization.py b/pkgs/node_helpers/node_helpers/spinning/initialization.py
new file mode 100644
index 0000000..3209053
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/spinning/initialization.py
@@ -0,0 +1,74 @@
+import logging
+import sys
+from collections.abc import Callable
+
+import rclpy
+from rclpy.executors import SingleThreadedExecutor
+from rclpy.node import Node
+
+from .executor import MultiThreadedStackTracedExecutor
+
+DEFAULT_MULTITHREADED_THREAD_COUNT = 16
+"""The default thread count should be thought of as "how many concurrent actions or
+services should a node be able to handle". In the future, this could be configurable,
+or a ScalableThreadedExecutor could be designed which dynamically increases or
+decreases thread count.
+
+This constant is also used for automated tests that use multi-threaded executors. 
+""" + + +def create_spin_function( + node_cls: type[Node], multi_threaded: bool = False +) -> Callable[[], None]: + """Returns a function that, when called, will initialize ROS and spin the given node + Usage Example: + In node.py + >>>main = create_spin_function(MyNode) + + In pyproject.toml + >>>[tool.poetry.scripts] + >>> my_node = "node:main" + + :param node_cls: The node class to initialize and spin + :param multi_threaded: If True, the node will be spun with a multi-threaded executor + :returns: A function that can be run to initialize the node + """ + + def spin_fn() -> None: + # Match the rclpy Node.get_logger() format closely, which is of format: + # [LOG_LEVEL] [UNIX_TIMESTAMP] [NODE_NAMESPACE.NODE_NAME]: the log message + # In the case of python logs we also add the field '[FILENAME:LINE_NUMBER]:' + node_name_placeholder = "unknown.node_name" + log_fmt = ( + f"[%(levelname)s] [%(created)s] [{node_name_placeholder}] " + "[%(filename)s:%(lineno)d]: %(message)s" + ) + logging.basicConfig(level=logging.DEBUG, format=log_fmt) + + rclpy.init(args=sys.argv) + node = node_cls() + + # Now that the node is initialized, we can re-configure the logger so that it + # contains the now-known node name and namespace information + node_namespace = node.get_namespace().replace("/", ".")[1:] + name_and_namespace = f"{node_namespace}.{node.get_name()}" + log_fmt = log_fmt.replace(node_name_placeholder, name_and_namespace) + logging.basicConfig(level=logging.DEBUG, format=log_fmt, force=True) + + # Create the appropriate executor + if multi_threaded: + executor = MultiThreadedStackTracedExecutor( + num_threads=DEFAULT_MULTITHREADED_THREAD_COUNT + ) + else: + executor = SingleThreadedExecutor() + + try: + rclpy.spin(node, executor) + except KeyboardInterrupt: + node.get_logger().warning("Closing due to SIGINT") + node.destroy_node() + rclpy.shutdown() + + return spin_fn diff --git a/pkgs/node_helpers/node_helpers/testing/__init__.py b/pkgs/node_helpers/node_helpers/testing/__init__.py new file mode 100644 index 0000000..e805974 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/testing/__init__.py @@ -0,0 +1,45 @@ +import faulthandler + +from .async_tools import run_and_cancel_task +from .callbacks import ActionServerCallback, ConfigurableServiceCallback +from .fixtures import each_test_setup_teardown +from .generators import exhaust_generator +from .messages import ( + ConstantPublisher, + expect_message, + messages_equal, + publish_and_expect_message, +) +from .nodes import ( + NodeForTesting, + rclpy_context, + set_up_external_node, + set_up_external_nodes_from_launchnode, + set_up_node, +) +from .resources import MessageResource, NumpyResource, resource_path +from .threads import ContextThread, DynamicContextThread, get_unclosed_threads +from .transforms import set_up_static_transforms + +faulthandler.enable() + +__all__ = [ + "each_test_setup_teardown", + "get_unclosed_threads", + "ConstantPublisher", + "messages_equal", + "expect_message", + "exhaust_generator", + "publish_and_expect_message", + "MessageResource", + "NumpyResource", + "resource_path", + "ConfigurableServiceCallback", + "ContextThread", + "NodeForTesting", + "set_up_external_node", + "set_up_external_nodes_from_launchnode", + "set_up_node", + "rclpy_context", + "run_and_cancel_task", +] diff --git a/pkgs/node_helpers/node_helpers/testing/actions.py b/pkgs/node_helpers/node_helpers/testing/actions.py new file mode 100644 index 0000000..ce09151 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/testing/actions.py @@ -0,0 +1,104 
@@
+from collections.abc import Callable
+from typing import Any
+from unittest.mock import Mock
+from uuid import UUID
+
+from action_msgs.msg import GoalStatus
+from rclpy import Future
+
+from node_helpers.futures import wait_for_future
+
+
+class MockClientGoalHandle(Mock):
+    """Plays the role of a ClientGoalHandle, but with controllable run times and result
+    values
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._result_future = Future()
+        self.cancel_result: Any | None = None
+
+    def get_result(self) -> Any:
+        return wait_for_future(self.get_result_async(), type_=object)
+
+    def get_result_async(self) -> Future:
+        return self._result_future
+
+    def cancel_goal(self) -> None:
+        self.cancel_goal_async()
+
+    @property
+    def status(self) -> GoalStatus:
+        return GoalStatus.STATUS_SUCCEEDED
+
+    def cancel_goal_async(self) -> Future:
+        action_is_already_finished = (
+            self._result_future is not None and self._result_future.done()
+        )
+
+        if self.cancel_result is None and not action_is_already_finished:
+            raise RuntimeError(
+                "Goal was cancelled but no cancel result has been set. Either set a "
+                "result for the action, or set the cancel result."
+            )
+
+        self.set_result(
+            self._result_future.result()
+            if self.cancel_result is None
+            else self.cancel_result
+        )
+
+        future = Future()
+        # Cancel immediately
+        future.set_result(Mock())
+        return future
+
+    def set_result(self, result: Any) -> None:
+        """Marks the goal handle as complete with the given result"""
+        self._result_future.set_result(result)
+
+
+class MockRobustActionClient(Mock):
+    """A RobustActionClient that produces MockClientGoalHandles when goals are sent"""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.called_with: list[Any] = []
+        self._result: Any | None = None
+        self._cancel_result: Any | None = None
+
+    def send_goal(self, goal: Any) -> Any:
+        return self.send_goal_async(goal).result()
+
+    def send_goal_async(
+        self,
+        goal: Any,
+        feedback_callback: Callable[[Any], None] | None = None,
+        goal_uuid: UUID | None = None,
+    ) -> Future:
+        self.called_with.append(goal)
+        client_goal_handle = MockClientGoalHandle()
+        if self._result is not None:
+            client_goal_handle.set_result(self._result)
+        if self._cancel_result is not None:
+            client_goal_handle.cancel_result = self._cancel_result
+        future = Future()
+        future.set_result(client_goal_handle)
+        return future
+
+    def set_result(self, result: Any) -> None:
+        """Marks all future MockClientGoalHandles as complete with the given result
+        value
+
+        :param result: The result to provide
+        """
+        self._result = result
+
+    def set_cancel_result(self, result: Any) -> None:
+        """Provides this result to all future MockClientGoalHandles that get cancelled.
+        MockClientGoalHandles may not be cancelled until they have a cancel result. 
+
+        :param result: The result to provide on cancellation
+        """
+        self._cancel_result = result
diff --git a/pkgs/node_helpers/node_helpers/testing/async_tools.py b/pkgs/node_helpers/node_helpers/testing/async_tools.py
new file mode 100644
index 0000000..dc8ec06
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/async_tools.py
@@ -0,0 +1,19 @@
+import asyncio
+from collections.abc import AsyncGenerator, Coroutine
+from contextlib import asynccontextmanager
+from typing import Any
+
+
+@asynccontextmanager
+async def run_and_cancel_task(
+    coro: Coroutine[Any, Any, Any],
+) -> AsyncGenerator[asyncio.Task[Any], None]:
+    """Runs a coroutine as a task for the duration of the context, then cancels the
+    task and waits for the cancellation to finish upon exit"""
+    task = asyncio.create_task(coro)
+    try:
+        yield task
+    finally:
+        task.cancel()
+        try:
+            await task
+        except asyncio.CancelledError:
+            pass
diff --git a/pkgs/node_helpers/node_helpers/testing/callbacks.py b/pkgs/node_helpers/node_helpers/testing/callbacks.py
new file mode 100644
index 0000000..5bbc765
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/callbacks.py
@@ -0,0 +1,135 @@
+from collections.abc import Iterable
+from threading import Event, RLock
+from time import sleep
+from typing import Any, Generic, TypeVar, cast
+
+from rclpy.action.server import ServerGoalHandle
+
+from node_helpers.pubsub import PublishEvent
+
+ReturnType = TypeVar("ReturnType")
+
+
+class ConfigurableServiceCallback(Generic[ReturnType]):
+    """A service callback that returns the next value from a preconfigured sequence
+    on each call."""
+
+    def __init__(self, return_values: Iterable[ReturnType]):
+        self._return_value_iterator = iter(return_values)
+        self._iterator_lock = RLock()
+        self.call_count: int = 0
+
+    def set_return_values(self, return_values: Iterable[ReturnType]) -> None:
+        """This method is useful for tests that want to override the default return
+        values"""
+        with self._iterator_lock:
+            self._return_value_iterator = iter(return_values)
+
+    def __call__(self, *args: Any, **kwargs: Any) -> ReturnType:
+        with self._iterator_lock:
+            self.call_count += 1
+            return next(self._return_value_iterator)
+
+
+class ActionServerCallback(Generic[ReturnType]):
+    def __init__(
+        self, return_values: Iterable[ReturnType], feedback_values: Iterable[Any] = ()
+    ) -> None:
+        """A helper for creating action callbacks that block, check for cancellation,
+        publish feedback, and otherwise allow control from the outside, so tests can
+        emulate a particular server behavior.
+
+        Instantiation example:
+        >>> action_callback = ActionServerCallback(
+        >>>     (SomeType.Response(1), SomeType.Response(2))
+        >>> )
+        >>> node.create_robust_service(
+        >>>     SomeType,
+        >>>     "some_name",
+        >>>     callback_group=MutuallyExclusiveCallbackGroup(),
+        >>>     callback=action_callback,
+        >>>     cancel_callback=lambda *args: CancelResponse.ACCEPT
+        >>> )
+
+        Usage example, testing whether an action calls cancel on a child action:
+        >>> some_client.send_goal_async()
+        >>> assert action_callback.on_action_started.wait(30)
+        >>> some_client.cancel_goals_async()
+        >>> assert action_callback.on_cancel_requested.wait(30)
+        >>> # Do some other stuff, then allow the cancellation to continue
+        >>> action_callback.allow_cancel.set()
+
+        :param return_values: An iterable object of the action responses
+        :param feedback_values: An iterable object of the action feedback
+        """
+        self.on_action_started = PublishEvent()
+        self.on_action_exiting = PublishEvent()
+        self.on_cancel_requested = PublishEvent()
+
+        # Users of this action can set these events when they are ready for the action
+        # to finish. 
+        self.allow_succeed = Event()
+        self.allow_abort = Event()
+        self.allow_cancel = Event()
+
+        # Feedback related controls (will publish feedback once per event set)
+        self.allow_publish_feedback = Event()
+
+        # This can be used to introspect the ServerGoalHandle
+        self.ongoing_goal_handle: ServerGoalHandle | None = None
+
+        self._return_value_iterator = ConfigurableServiceCallback(return_values)
+        self._feedback_iterator = ConfigurableServiceCallback(feedback_values)
+
+    @property
+    def call_count(self) -> int:
+        return self._return_value_iterator.call_count
+
+    def set_return_values(self, return_values: Iterable[ReturnType]) -> None:
+        """Allows overriding the default action 'return' value list"""
+        self._return_value_iterator.set_return_values(return_values)
+
+    def set_feedback_values(self, feedback_values: Iterable[Any]) -> None:
+        """Allows overriding the default action 'feedback' value list"""
+        self._feedback_iterator.set_return_values(feedback_values)
+
+    @property
+    def ongoing_goal(self) -> Any:
+        return cast(ServerGoalHandle, self.ongoing_goal_handle).request
+
+    def __call__(self, goal: ServerGoalHandle) -> ReturnType:
+        # Track the current goal handle, for test writing
+        if self.ongoing_goal_handle is not None:
+            raise RuntimeError(
+                "ActionServerCallback isn't designed for multithreaded use!"
+            )
+        self.ongoing_goal_handle = goal
+
+        # Now that 'ongoing_goal_handle' is tracked, mark the action as 'started'
+        self.on_action_started.set()
+
+        # Wait for the action to be allowed to either finish or be cancelled
+        while True:
+            if self.allow_publish_feedback.is_set():
+                feedback = self._feedback_iterator()
+                goal.publish_feedback(feedback)
+                # Clear the event so feedback is published once per 'set', as documented
+                self.allow_publish_feedback.clear()
+
+            if self.allow_succeed.is_set():
+                goal.succeed()
+                break
+
+            if self.allow_abort.is_set():
+                goal.abort()
+                break
+
+            if goal.is_cancel_requested:
+                self.on_cancel_requested.set()
+
+                if self.allow_cancel.is_set():
+                    goal.canceled()
+                    break
+            sleep(0.1)
+
+        self.on_action_exiting.set()
+        self.ongoing_goal_handle = None
+
+        return self._return_value_iterator()
diff --git a/pkgs/node_helpers/node_helpers/testing/fixtures.py b/pkgs/node_helpers/node_helpers/testing/fixtures.py
new file mode 100644
index 0000000..00bbc5c
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/fixtures.py
@@ -0,0 +1,27 @@
+from collections.abc import Generator
+
+import pytest
+
+from .threads import get_unclosed_threads
+
+
+@pytest.fixture(autouse=True)
+def each_test_setup_teardown() -> Generator[None, None, None]:  # noqa: PT004
+    """This fixture should be imported in the root `conftest.py` of every package.
+
+    It will validate that all threads have been closed between tests, and end the
+    testing session if not.
+
+    :yields: Nothing. There's no need to depend on this fixture as long as it's imported
+    """
+    yield
+
+    # Any teardown for all tests goes here
+    unclosed_threads = get_unclosed_threads()
+    if len(unclosed_threads):
+        msg = (
+            "There were unclosed threads after fixture teardown: "
+            f"{unclosed_threads=}.\nThe testing session will end unfinished."
+        )
+
+        pytest.exit(msg, returncode=-1)
diff --git a/pkgs/node_helpers/node_helpers/testing/generators.py b/pkgs/node_helpers/node_helpers/testing/generators.py
new file mode 100644
index 0000000..cd708ac
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/generators.py
@@ -0,0 +1,21 @@
+from collections.abc import Generator
+from typing import Any, TypeVar
+
+RET_VAL = TypeVar("RET_VAL")
+YIELD_VAL = TypeVar("YIELD_VAL")
+INPUT_VAL = TypeVar("INPUT_VAL")
+
+
+def exhaust_generator(
+    generator: Generator[YIELD_VAL, Any, RET_VAL],
+) -> tuple[tuple[YIELD_VAL, ...], RET_VAL]:
+    """Exhaust a generator and return the yielded values and the return value"""
+    yielded: list[YIELD_VAL] = []
+
+    try:
+        while True:
+            yielded.append(next(generator))
+    except StopIteration as e:
+        returned = e.value
+
+    return tuple(yielded), returned
diff --git a/pkgs/node_helpers/node_helpers/testing/messages.py b/pkgs/node_helpers/node_helpers/testing/messages.py
new file mode 100644
index 0000000..eb60d0e
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/messages.py
@@ -0,0 +1,140 @@
+import math
+import queue
+from collections.abc import Callable, Iterable
+from time import sleep
+from typing import Any, TypeAlias, TypeVar, cast
+
+from rclpy.publisher import Publisher
+from sensor_msgs.msg import JointState
+from std_msgs.msg import Float64MultiArray
+
+from .threads import ContextThread
+
+
+class ConstantPublisher(ContextThread):
+    """Constantly publishes a message to a publisher. This class should be used as a
+    context manager.
+    """
+
+    def __init__(self, publisher: Publisher, outgoing: Any) -> None:
+        """
+        :param publisher: The publisher to send messages on
+        :param outgoing: The message to send
+        """
+        super().__init__("Constant Publisher")
+
+        self._publisher = publisher
+        self._outgoing = outgoing
+
+    def run(self) -> None:
+        while self.running:
+            self._publisher.publish(self._outgoing)
+            sleep(0.01)
+
+
+_EQUALITY_MSG: TypeAlias = Float64MultiArray | JointState
+
+T = TypeVar("T", bound=_EQUALITY_MSG)
+
+
+def messages_equal(actual: T, expected: T) -> bool:
+    """Checks if the given messages are equal"""
+    is_equal = _get_equality_func(type(expected))
+    return is_equal(actual, expected)
+
+
+def expect_message(
+    messages: "queue.Queue[T]",
+    expected: T,
+    max_tries: int = 30,
+) -> None:
+    """Reads from the given queue until the expected message has arrived at least 5
+    times, ensuring the incoming messages have stabilized at that value.
+
+    :param messages: A queue where incoming messages are put
+    :param expected: The expected message
+    :param max_tries: The maximum number of messages to read before giving up
+    :raises AssertionError: If the expected value was not received
+    """
+    _flush_queue(messages)
+
+    expected_count = 0
+    msg = None
+
+    for _ in range(max_tries):
+        msg = messages.get(timeout=5)
+        if messages_equal(msg, expected):
+            expected_count += 1
+            if expected_count >= 5:
+                return
+
+    fail_message = f"Messages failed to stabilize at value {expected}"
+    if msg is not None:
+        fail_message += f", latest value is {msg}"
+
+    raise AssertionError(fail_message)
+
+
+def publish_and_expect_message(
+    publisher: Publisher,
+    outgoing: Any,
+    messages: "queue.Queue[_EQUALITY_MSG]",
+    expected: _EQUALITY_MSG,
+    max_tries: int = 30,
+) -> None:
+    """Continuously publishes the given message and checks for the desired result on
+    the incoming message queue.
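+
+    A hypothetical sketch, where `cmd_pub` publishes Float64MultiArray commands and
+    `feedback` is a queue filled by a subscription:
+    >>> publish_and_expect_message(
+    >>>     cmd_pub,
+    >>>     Float64MultiArray(data=[1.0]),
+    >>>     feedback,
+    >>>     expected=Float64MultiArray(data=[1.0]),
+    >>> )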
+
+    :param publisher: The publisher to publish the outgoing message to
+    :param outgoing: The outgoing message to publish
+    :param messages: A queue where incoming messages are put
+    :param expected: The expected message
+    :param max_tries: The maximum number of messages to read before giving up
+    :raises AssertionError: If the expected value was not received
+    """
+    with ConstantPublisher(publisher, outgoing):
+        expect_message(messages, expected, max_tries)
+
+
+_EQUALITY_FUNC = Callable[[_EQUALITY_MSG, _EQUALITY_MSG], bool]
+
+_EQUALITY_FUNCS: dict[type[_EQUALITY_MSG], _EQUALITY_FUNC] = {}
+
+
+def _float64_multi_array(a: Float64MultiArray, b: Float64MultiArray) -> bool:
+    return bool(a.data.tolist() == b.data.tolist())
+
+
+def _joint_state(a: JointState, b: JointState) -> bool:
+    def all_are_close(values: Iterable[tuple[float, float]]) -> bool:
+        return all(math.isclose(val1, val2, rel_tol=0.001) for val1, val2 in values)
+
+    return (
+        a.name == b.name
+        and all_are_close(zip(a.position, b.position, strict=False))
+        and all_are_close(zip(a.velocity, b.velocity, strict=False))
+        and all_are_close(zip(a.effort, b.effort, strict=False))
+    )
+
+
+_EQUALITY_FUNCS[Float64MultiArray] = _float64_multi_array
+
+_EQUALITY_FUNCS[JointState] = _joint_state
+
+
+def _get_equality_func(msg_type: type[_EQUALITY_MSG]) -> _EQUALITY_FUNC:
+    try:
+        return _EQUALITY_FUNCS[msg_type]
+    except KeyError:
+        # Use default equality testing
+        return lambda a, b: cast(bool, a == b)
+
+
+def _flush_queue(q: "queue.Queue[Any]") -> None:
+    """Empties the queue of existing items"""
+    while True:
+        try:
+            q.get_nowait()
+        except queue.Empty:
+            break
diff --git a/pkgs/node_helpers/node_helpers/testing/nodes.py b/pkgs/node_helpers/node_helpers/testing/nodes.py
new file mode 100644
index 0000000..962a0a4
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/nodes.py
@@ -0,0 +1,340 @@
+import contextlib
+import logging
+from collections.abc import Callable, Generator
+from functools import partial
+from multiprocessing import Process
+from pathlib import Path
+from queue import Queue
+from threading import Thread
+from time import sleep
+from typing import TypeVar
+
+import psutil
+import rclpy
+from launch import LaunchDescription, LaunchService
+from launch_ros.actions import Node as LaunchNode
+from launch_ros.parameter_descriptions import Parameter as LaunchParameter
+from rclpy.context import Context
+from rclpy.executors import ExternalShutdownException, SingleThreadedExecutor
+from rclpy.node import Node
+from rclpy.parameter import Parameter
+from rclpy.qos import QoSProfile, qos_profile_services_default
+from rclpy.task import Future
+
+from node_helpers.nodes import HelpfulNode
+from node_helpers.parameters import Namespace, ParameterLoader
+from node_helpers.spinning import (
+    DEFAULT_MULTITHREADED_THREAD_COUNT,
+    MultiThreadedStackTracedExecutor,
+)
+from node_helpers.testing import each_test_setup_teardown  # noqa: F401
+from node_helpers.timing import TestingTimeout as Timeout
+
+MESSAGE_TYPE = TypeVar("MESSAGE_TYPE")
+
+
+class NodeForTesting(HelpfulNode):
+    """A ROS node, augmented with methods that are useful for testing"""
+
+    def create_queue_subscription(
+        self,
+        type_: type[MESSAGE_TYPE],
+        topic: str,
+        *,
+        qos_profile: QoSProfile = qos_profile_services_default,
+    ) -> "Queue[MESSAGE_TYPE]":
+        """Subscribes to a topic, putting the results in a queue.
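+
+        A hypothetical sketch (assumes a std_msgs.msg String on a "chatter" topic):
+        >>> msgs = node.create_queue_subscription(String, "chatter")
+        >>> first = msgs.get(timeout=5)  # Blocks until a message arrives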
+
+        :param type_: The type of message this topic uses
+        :param topic: The topic to subscribe to
+        :param qos_profile: The QoS profile to use for the subscription
+        :return: The queue where values will be put
+        """
+        messages: "Queue[MESSAGE_TYPE]" = Queue()
+
+        self.create_subscription(
+            type_,
+            topic,
+            messages.put,
+            qos_profile=qos_profile,
+        )
+
+        return messages
+
+    def create_subscription_future(
+        self,
+        msg_type: type[MESSAGE_TYPE],
+        topic: str,
+        qos_profile: QoSProfile = qos_profile_services_default,
+    ) -> Future:
+        """Returns a future that finishes when a message is received on a topic"""
+        future = Future()
+
+        self.create_subscription(
+            msg_type=msg_type,
+            topic=topic,
+            qos_profile=qos_profile,
+            callback=future.set_result,
+        )
+
+        return future
+
+
+NODE_CLASS = TypeVar("NODE_CLASS", bound=Node)
+
+_INITIALIZED_TIMEOUT = 5
+
+
+@contextlib.contextmanager
+def rclpy_context() -> Generator[Context, None, None]:
+    """A context manager for initializing a ROS context then cleaning up afterward"""
+    context = Context()
+    rclpy.init(context=context)
+
+    try:
+        yield context
+    finally:
+        rclpy.shutdown(context=context, uninstall_handlers=True)
+
+
+def set_up_node(
+    node_class: type[NODE_CLASS],
+    namespace: str,
+    node_name: str,
+    *,
+    default_params_directory: Path | str | None = None,
+    parameter_overrides: list[Parameter] | None = None,
+    remappings: dict[str, str] | None = None,
+    multi_threaded: bool = False,
+) -> Generator[NODE_CLASS, None, None]:
+    """Sets up the given node class in a new ROS context with its own executor.
+    Doing it this way better simulates runtime behavior, where each node is its own
+    separate process. It also prevents some types of deadlocks that would not exist
+    outside of testing.
+
+    :param node_class: The class of the node to set up
+    :param namespace: A namespace to put the node in
+    :param node_name: The name to give the node
+    :param default_params_directory: A directory to load parameters from. This is used
+        when it's convenient to load parameters from YAML files, instead of specifying
+        them in code. YAML files will be loaded and overridden in alphanumeric order.
+        Only the parameters for the given {namespace}.{node_name} will be loaded.
+    :param parameter_overrides: Parameter values to provide to the node, overriding
+        their values in configuration files
+    :param remappings: Remappings for topics and other names, where the key is the
+        name to be remapped, and the value is the new name
+    :param multi_threaded: If True, the node will be started with a multi-threaded
+        executor
+    :yields: A node instance, ready to go
+    """
+
+    namespace_obj = Namespace([])
+    namespace_obj += Namespace.from_string(namespace)
+    namespace_obj += Namespace.from_string(node_name)
+
+    # Optionally load parameters from a directory
+    params = []
+    if default_params_directory is not None:
+        base_config_path = _find_path_in_parents(default_params_directory)
+        params = ParameterLoader(
+            parameters_directory=base_config_path
+        ).parameters_for_node(namespace_obj)
+
+    if parameter_overrides is not None:
+        params = _merge_parameters(params, parameter_overrides)
+
+    if remappings is None:
+        remappings = {}
+
+    cli_args = ["--ros-args"]
+    for map_from, map_to in remappings.items():
+        cli_args += ["--remap", f"{map_from}:={map_to}"]
+
+    with rclpy_context() as context:
+        if multi_threaded:
+            executor = MultiThreadedStackTracedExecutor(
+                context=context, num_threads=DEFAULT_MULTITHREADED_THREAD_COUNT
+            )
+        else:
+            executor = SingleThreadedExecutor(context=context)
+
+        node = node_class(
+            context=context,
+            namespace=namespace,
+            parameter_overrides=params,
+            cli_args=cli_args,
+        )
+
+        def spin() -> None:
+            try:
+                rclpy.spin(node, executor)
+            except ExternalShutdownException:
+                pass  # This exception always happens when shutting down the executor
+
+        spin_node_thread = Thread(
+            target=spin,
+            name=f"Node Spin {node_class.__name__}",
+            daemon=True,
+        )
+        spin_node_thread.start()
+
+        # Wait for rclpy.spin to set the executor
+        while node.executor is None:
+            sleep(0.01)
+
+        yield node
+
+        # Explicitly destroy the node, instead of letting garbage collection do it later.
+        # This ensures that node-specific cleanup operations happen in advance of other
+        # shutdown procedures.
+        node.destroy_node()
+        assert executor.shutdown(timeout_sec=30), "Failed to shut down executor!"
+        spin_node_thread.join(timeout=30)
+        assert not spin_node_thread.is_alive()
+
+
+def set_up_external_node(
+    package_name: str,
+    executable: str,
+    *,
+    is_ready: Callable[[Process], bool] | None = None,
+    ready_timeout: float = _INITIALIZED_TIMEOUT,
+    parameters: list[Parameter] | None = None,
+    remappings: dict[str, str] | None = None,
+    namespace: str | None = None,
+) -> Generator[Process, None, None]:
+    """Starts a node that is defined outside our codebase.
+
+    :param package_name: The package the executable is provided in
+    :param executable: The name of the node executable
+    :param is_ready: A function that can be called to check if the node is ready
+    :param ready_timeout: The maximum amount of time to wait on the node to be ready
+    :param parameters: Parameter values to assign to the node
+    :param remappings: Remappings for topics and other names, where the key is the
+        name to be remapped, and the value is the new name
+    :param namespace: A namespace to move the node to. If this is not provided,
+        the node is launched in the root namespace. This namespace must be prefixed
+        with a forward slash, indicating that it is relative to the root namespace.
+    :yields: The running process
+    """
+
+    # Reformat arguments to match the weird formats launch_ros uses
+    remappings_launch = None
+    if remappings is not None:
+        remappings_launch = ((k, v) for k, v in remappings.items())
+    parameters_launch = None
+    if parameters is not None:
+        parameters_launch = [LaunchParameter(p.name, p.value) for p in parameters]
+
+    launch_node = LaunchNode(
+        package=package_name,
+        executable=executable,
+        namespace=namespace,
+        remappings=remappings_launch,
+        parameters=parameters_launch,
+    )
+    yield from set_up_external_nodes_from_launchnode(
+        [launch_node], is_ready=is_ready, ready_timeout=ready_timeout
+    )
+
+
+def set_up_external_nodes_from_launchnode(
+    nodes: list[LaunchNode],
+    is_ready: Callable[[Process], bool] | None = None,
+    ready_timeout: float = _INITIALIZED_TIMEOUT,
+) -> Generator[Process, None, None]:
+    """Runs already-prepared launch_ros Node actions in a new process"""
+
+    launch_desc = LaunchDescription(nodes)
+    launch_service = LaunchService()
+    launch_service.include_launch_description(launch_desc)
+
+    proc = Process(
+        name=f"Launch Service for {nodes=}",
+        target=partial(launch_service.run, shutdown_when_idle=False),
+        daemon=True,
+    )
+    proc.start()
+
+    if is_ready is not None:
+        timeout = Timeout(
+            ready_timeout,
+            timeout_message=f"Node {nodes} took too long to become ready",
+        )
+        while not is_ready(proc) and timeout:
+            sleep(0.01)
+
+    yield proc
+
+    # Ensure that the process did not exit early
+    assert proc.is_alive()
+    assert proc.pid is not None
+
+    # Start killing all the children in the process group, wait, then kill the parent
+    children_to_kill = psutil.Process(proc.pid).children(recursive=True)
+    for child in children_to_kill:
+        child.kill()
+
+    for child in children_to_kill:
+        try:
+            child.wait(30)
+        except Exception:
+            logging.exception("Child process failed to exit")
+
+        assert not child.is_running()
+
+    proc.kill()
+
+    # Block until the process exits
+    proc.join(10)
+    assert not proc.is_alive()
+    assert proc.exitcode is not None
+    proc.close()
+
+
+def _find_path_in_parents(path: Path | str) -> Path:
+    """Travels up the directory tree until the provided path can be found"""
+    # Include the current directory itself, as promised by the error message below
+    for parent in [Path.cwd(), *Path.cwd().parents]:
+        if (parent / path).exists():
+            return parent / path
+
+    raise RuntimeError(
+        f"Could not find {path} in {Path.cwd()} or any of its parent folders"
+    )
+
+
+def _parameter_to_string(param: Parameter) -> str:
+    if isinstance(param.value, list):
+        # Lists take on the format [val1,val2,val3]
+        output = "["
+        values_str = [str(v) for v in param.value]
+        if param.type_ is Parameter.Type.STRING_ARRAY:
+            # Wrap the string in quotes to avoid parsing issues
+            values_str = [f"'{v}'" for v in values_str]
+        output += ",".join(values_str)
+        output += "]"
+    elif isinstance(param.value, str):
+        # Wrap the string in quotes to avoid parsing issues if the string has special
+        # characters in it
+        output = f"'{param.value}'"
+    else:
+        output = str(param.value)
+
+    return output
+
+
+def _merge_parameters(
+    base: list[Parameter], override: list[Parameter]
+) -> list[Parameter]:
+    """Combines the two lists of parameters, using the values from the second list in
+    the case of parameter name collisions
+
+    :param base: The base parameter list
+    :param override: The parameters to apply over the base list
+    :return: Combined parameters
+    """
+
+    base_dict = {v.name: v for v in base}
+    override_dict = {v.name: v for v in override}
+    base_dict.update(override_dict)
+    return list(base_dict.values())
diff --git a/pkgs/node_helpers/node_helpers/testing/resources/__init__.py b/pkgs/node_helpers/node_helpers/testing/resources/__init__.py
new file mode 100644
index 0000000..833a880
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/resources/__init__.py
@@ -0,0 +1,9 @@
+from .message_resource import MessageResource
+from .numpy_resource import NumpyResource
+from .resource_paths import resource_path
+
+__all__ = [
+    "resource_path",
+    "NumpyResource",
+    "MessageResource",
+]
diff --git a/pkgs/node_helpers/node_helpers/testing/resources/message_resource.py b/pkgs/node_helpers/node_helpers/testing/resources/message_resource.py
new file mode 100644
index 0000000..44bb146
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/resources/message_resource.py
@@ -0,0 +1,22 @@
+from pathlib import Path
+from typing import Generic, TypeVar
+
+from rclpy.serialization import deserialize_message
+
+from .resource_paths import resource_path
+
+MsgType = TypeVar("MsgType")
+
+
+class MessageResource(Generic[MsgType]):
+    """A helper class for loading serialized ROS messages"""
+
+    def __init__(self, *paths: str | Path, msg_type: type[MsgType]):
+        self.path = resource_path(*paths)
+        self.msg_type = msg_type
+
+    @property
+    def msg(self) -> MsgType:
+        msg_bytes = self.path.read_bytes()
+        msg: MsgType = deserialize_message(msg_bytes, self.msg_type)
+        return msg
diff --git a/pkgs/node_helpers/node_helpers/testing/resources/numpy_resource.py b/pkgs/node_helpers/node_helpers/testing/resources/numpy_resource.py
new file mode 100644
index 0000000..1223197
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/resources/numpy_resource.py
@@ -0,0 +1,19 @@
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+
+from .resource_paths import resource_path
+
+
+class NumpyResource:
+    """A helper class for loading *.npy files"""
+
+    def __init__(self, *paths: str | Path):
+        self.path = resource_path(*paths)
+
+    @property
+    def array(self) -> npt.NDArray[Any]:
+        with self.path.open("rb") as npy_file:
+            array: npt.NDArray[Any] = np.load(npy_file)
+            return array
diff --git a/pkgs/node_helpers/node_helpers/testing/resources/resource_paths.py b/pkgs/node_helpers/node_helpers/testing/resources/resource_paths.py
new file mode 100644
index 0000000..f173796
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/resources/resource_paths.py
@@ -0,0 +1,5 @@
+from pathlib import Path
+
+
+def resource_path(*paths: str | Path) -> Path:
+    return Path(*paths).resolve(strict=True)
diff --git a/pkgs/node_helpers/node_helpers/testing/threads.py b/pkgs/node_helpers/node_helpers/testing/threads.py
new file mode 100644
index 0000000..745e3ca
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/threads.py
@@ -0,0 +1,101 @@
+import re
+import threading
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from functools import partial
+from types import TracebackType
+from typing import Any, TypeVar
+
+T = TypeVar("T", bound="ContextThread")
+
+
+class ContextThread(ABC):
+    """Manages the lifetime of a thread using a context manager"""
+
+    def __init__(self, name: str) -> None:
+        self.running = False
+        self.exception: Exception | None = None
+        """If an exception is raised, it'll be re-raised upon exiting the context"""
+
+        self._thread = threading.Thread(
+            name=name,
+            target=self._record_exceptions,
+            daemon=True,
+        )
+
+    @abstractmethod
+    def run(self) -> None:
+        """The behavior to run in the thread"""
+
+    def _record_exceptions(self) -> None:
+        try:
+            self.run()
+        except Exception as e:  # noqa: BLE001
+            self.exception = e
+
+    def __enter__(self: T) -> T:
+        self.running = True
+        self.exception = None
+        self._thread.start()
+        return self
+
+    def __exit__(
+        self,
+        _exc_type: type[BaseException] | None,
+        _exc_val: BaseException | None,
+        _exc_tb: TracebackType | None,
+    ) -> None:
+        self.running = False
+        self._thread.join(timeout=5.0)
+        if self._thread.is_alive():
+            raise RuntimeError(f"Thread {self._thread.name} did not stop")
+
+        if self.exception:
+            raise self.exception
+
+
+class DynamicContextThread(ContextThread):
+    """A ContextThread that lets you pass a function and arguments on initialization
+
+    Mirrors the Thread(...) initialization API
+    """
+
+    def __init__(self, target: Callable[..., Any], *args: Any, **kwargs: Any):
+        super().__init__(name=f"{target.__name__} DynamicContextThread")
+        self._target = partial(target, *args, **kwargs)
+
+    def run(self) -> None:
+        self._target()
+
+
+def get_unclosed_threads(allowable_threads: list[str] | None = None) -> list[str]:
+    """A convenience function that returns threads that shouldn't be alive between
+    tests.
+
+    :param allowable_threads: A list of regular expressions that match to
+        thread names that are allowed to stay alive between tests
+    :returns: A list of thread names that didn't match the allowable thread regexes
+    """
+    if allowable_threads is None:
+        allowable_threads = []
+    allowable_threads += [
+        r"pydevd\.Writer",
+        r"pydevd\.Reader",
+        r"pydevd\.CommandThread",
+        r"profiler\.Reader",
+        r"MainThread",
+    ]
+
+    open_threads = []
+
+    for thread in threading.enumerate():
+        matched = False
+        for name_pattern in allowable_threads:
+            if re.match(name_pattern, thread.name):
+                matched = True
+                break
+
+        if not matched:
+            open_threads.append(thread)
+
+    return [o.name for o in open_threads]
diff --git a/pkgs/node_helpers/node_helpers/testing/timing.py b/pkgs/node_helpers/node_helpers/testing/timing.py
new file mode 100644
index 0000000..5cf7abd
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/timing.py
@@ -0,0 +1,25 @@
+from node_helpers.timing import Timeout
+
+
+class MockTimeout(Timeout):
+    """Handy mock for Timeout that is always expired by default.
+    Helpful for testing modules that sleep or wait for timeouts."""
+
+    def __init__(
+        self,
+        seconds: float,
+        raise_error: bool = False,
+        timeout_message: str = "Timeout",
+        default_state: bool = False,  # Always expired by default
+    ) -> None:
+        self.seconds = seconds
+        self.mock_timeout_state = default_state
+
+    def reset_seconds(self, seconds: float) -> None:
+        self.seconds = seconds
+
+    def __bool__(self) -> bool:
+        return self.mock_timeout_state
+
+    def set_active(self, active: bool) -> None:
+        self.mock_timeout_state = active
diff --git a/pkgs/node_helpers/node_helpers/testing/transforms.py b/pkgs/node_helpers/node_helpers/testing/transforms.py
new file mode 100644
index 0000000..28e3dcb
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/testing/transforms.py
@@ -0,0 +1,48 @@
+from collections.abc import Generator
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+from rclpy import Parameter
+
+from node_helpers.nodes import InteractiveTransformPublisher
+from node_helpers.testing import set_up_node
+
+
+def set_up_static_transforms(
+    *parents_to_children: tuple[str, str], namespace: str = "calibration"
+) -> Generator[InteractiveTransformPublisher, None, None]:
+    """This function can be used to create and publish static transforms in a fixture.
+    It operates much like set_up_node. 
Here's how to use it: + + >>> @pytest.fixture + >>> def world_wood_robot_transforms(): + >>> yield from set_up_static_transforms( + >>> ("world", "base_link"), + >>> ("base_link", "wood"), + >>> ("base_link", "robot"), + >>> ) + + The above example would create static transforms between each parent and child. + + :param parents_to_children: The parent -> child TF pairs to create + :param namespace: The namespace to use for the node + :yields: A configured, launched InteractiveTransformPublisher node + """ + parents_to_children = ( + parents_to_children if parents_to_children else (("world", "base_link"),) + ) + transforms_str = [f"{parent}:{child}" for parent, child in parents_to_children] + + with TemporaryDirectory() as tempdir: + temp_config = Path(tempdir) / "transforms.json" + yield from set_up_node( + node_class=InteractiveTransformPublisher, + namespace=namespace, + multi_threaded=True, + node_name="static_transforms", + parameter_overrides=[ + Parameter(name="static_transforms_file", value=str(temp_config)), + Parameter(name="tf_publish_frequency", value=2.0), + Parameter(name="transforms", value=transforms_str), + ], + ) diff --git a/pkgs/node_helpers/node_helpers/tf/README.md b/pkgs/node_helpers/node_helpers/tf/README.md new file mode 100644 index 0000000..5924a41 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/tf/README.md @@ -0,0 +1,3 @@ +# node_helpers.tf + +This module provides tools for working with ROS2 Transforms. It includes utilities for broadcasting, manipulating, and validating transformations in dynamic and static contexts. diff --git a/pkgs/node_helpers/node_helpers/tf/__init__.py b/pkgs/node_helpers/node_helpers/tf/__init__.py new file mode 100644 index 0000000..5ab8019 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/tf/__init__.py @@ -0,0 +1,2 @@ +from .constant_broadcaster import ConstantStaticTransformBroadcaster +from .movement import block_until_tfs_are_static, tf_velocity diff --git a/pkgs/node_helpers/node_helpers/tf/constant_broadcaster.py b/pkgs/node_helpers/node_helpers/tf/constant_broadcaster.py new file mode 100644 index 0000000..f8926f2 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/tf/constant_broadcaster.py @@ -0,0 +1,47 @@ +from geometry_msgs.msg import TransformStamped +from rclpy.callback_groups import MutuallyExclusiveCallbackGroup +from rclpy.node import Node +from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile +from tf2_msgs.msg import TFMessage + + +class ConstantStaticTransformBroadcaster: + """This object sets up a ros timer to repeatedly re-broadcast a static transform. + + It's an analog of StaticTransformBroadcaster, but it's designed to periodically + re-broadcast the same transform, rather than sending it once. This is after some + observations that QoS latching behavior wasn't 100% reliable in production... 
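+
+    A hypothetical usage sketch, inside a node's __init__ (where tf is a prepared
+    TransformStamped):
+    >>> self.broadcaster = ConstantStaticTransformBroadcaster(self, tf)
+    >>> self.broadcaster.set_transform(tf)  # Republishes immediately on update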
+    """
+
+    def __init__(
+        self,
+        node: Node,
+        initial_transform: TransformStamped | None = None,
+        publish_seconds: float = 1,
+    ):
+        """
+        :param node: The node to create the timer and publish with
+        :param initial_transform: Optional, the transform to publish initially
+        :param publish_seconds: How many seconds between repeat publishes
+        """
+        qos = QoSProfile(
+            depth=1,
+            durability=DurabilityPolicy.TRANSIENT_LOCAL,
+            history=HistoryPolicy.KEEP_LAST,
+        )
+        self.pub_tf = node.create_publisher(TFMessage, "/tf_static", qos)
+
+        # Avoid broadcasting a [None] transform when no initial transform is given
+        transforms = [] if initial_transform is None else [initial_transform]
+        self._net_message = TFMessage(transforms=transforms)
+        node.create_timer(
+            publish_seconds,
+            callback=self._publish_transform,
+            callback_group=MutuallyExclusiveCallbackGroup(),
+        )
+
+    def _publish_transform(self) -> None:
+        self.pub_tf.publish(self._net_message)
+
+    def set_transform(self, transform: TransformStamped) -> None:
+        """Update the transform that is constantly being broadcasted"""
+        self._net_message.transforms = [transform]
+        self._publish_transform()
diff --git a/pkgs/node_helpers/node_helpers/tf/movement.py b/pkgs/node_helpers/node_helpers/tf/movement.py
new file mode 100644
index 0000000..a78b7f3
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/tf/movement.py
@@ -0,0 +1,96 @@
+from time import time
+
+import numpy as np
+import numpy.typing as npt
+from builtin_interfaces.msg import Time
+from rclpy.duration import Duration
+from rclpy.time import Time as RclpyTime
+from tf2_ros import Buffer
+
+from node_helpers.ros2_numpy import numpify
+from node_helpers.tf import timestamps
+from node_helpers.timing import Timeout
+
+
+def tf_velocity(
+    frame_id: str,
+    tf_buffer: Buffer,
+    at_time: Time,
+    sample_duration: float = 0.005,
+    relative_to_frame: str = "base_link",
+    tf_timeout: int | None = 10,
+) -> npt.NDArray[np.float64]:
+    """Returns the velocity in meters per second along the X, Y, and Z axes of the
+    given frame_id.
+
+    :param frame_id: The frame_id to get the velocity of
+    :param tf_buffer: The tf buffer to use
+    :param at_time: The time to get the velocity at
+    :param sample_duration: The duration to sample the velocity over
+    :param relative_to_frame: The frame to get the velocity relative to
+    :param tf_timeout: The timeout to wait for the transform
+    :return: The (X, Y, Z) velocities in meters per second
+    :raises ValueError: If the at_time is not a valid timestamp
+    """
+    if at_time == Time():
+        raise ValueError("This function must have a valid timestamp!")
+
+    tf_timeout = tf_timeout or None
+
+    past_time = RclpyTime.from_msg(at_time) - Duration(seconds=sample_duration)
+    vector = tf_buffer.lookup_transform_full(
+        target_frame=frame_id,
+        target_time=past_time,
+        source_frame=frame_id,
+        source_time=at_time,
+        fixed_frame=relative_to_frame,
+        timeout=Duration(seconds=tf_timeout),
+    ).transform.translation
+
+    # Convert the translation over the sample duration into a per-axis velocity, in
+    # meters per second
+    return numpify(vector) / sample_duration
+
+
+def block_until_tfs_are_static(
+    frame_ids: list[str],
+    tf_buffer: Buffer,
+    start_time: float | None = None,
+    timeout: int | None = 10,
+    time_increment: float = 0.01,
+    velocity_tolerance: float = 0.0001,
+) -> None:
+    """Waits until each of the given TFs is stable and no longer moving
+
+    :param frame_ids: The frame_ids to wait for 0 velocity
+    :param tf_buffer: The tf buffer to use
+    :param start_time: The minimum time to start waiting for static tfs. If unset,
+        the current time is used.
+    :param timeout: The maximum time to wait for the tfs to become static. 
+ :param time_increment: The time increment to use when checking for static tfs. + :param velocity_tolerance: The "static" tolerance to use when checking for tfs. + """ + velocity_at_time = start_time or time() + msg = f"Timed out while waiting for frames to be static: {frame_ids}" + retry_timeout = ( + Timeout(timeout, raise_error=True, timeout_message=msg) if timeout else True + ) + while retry_timeout: + velocities = [] + for frame_id in frame_ids: + velocity = tf_velocity( + frame_id=frame_id, + tf_buffer=tf_buffer, + at_time=timestamps.unix_timestamp_to_ros(velocity_at_time), + sample_duration=0.005, + tf_timeout=timeout, + ) + velocities.append(round(np.linalg.norm(velocity), 6)) + + if all( + np.linalg.norm(velocity) <= velocity_tolerance for velocity in velocities + ): + break + + # Since velocities weren't stable at velocity_at_time, try again but with a time + # increment + velocity_at_time += time_increment diff --git a/pkgs/node_helpers/node_helpers/tf/timestamps.py b/pkgs/node_helpers/node_helpers/tf/timestamps.py new file mode 100644 index 0000000..05627ca --- /dev/null +++ b/pkgs/node_helpers/node_helpers/tf/timestamps.py @@ -0,0 +1,29 @@ +from typing import cast + +from builtin_interfaces.msg import Time +from rclpy.time import Time as RclpyTime + +CONVERSION_CONSTANT = 10**9 + + +def ros_stamp_to_unix_timestamp(stamp: Time | RclpyTime) -> float: + """Convert a ros2 Time message to a unix timestamp float""" + rclpy_msg = RclpyTime.from_msg(stamp) if isinstance(stamp, Time) else stamp + return cast(float, rclpy_msg.nanoseconds / CONVERSION_CONSTANT) + + +def unix_timestamp_to_ros(stamp: float) -> Time: + """Convert a unix timestamp to a ros2 Time message""" + seconds = int(stamp) + nano_seconds = int((stamp - seconds) * CONVERSION_CONSTANT) + return Time(sec=seconds, nanosec=nano_seconds) + + +def is_newer(a: Time, b: Time) -> bool: + """Check if a > b""" + return cast(bool, a.sec > b.sec or (a.sec == b.sec and a.nanosec > b.nanosec)) + + +def is_older(a: Time, b: Time) -> bool: + """Check if a < b""" + return cast(bool, a.sec < b.sec or (a.sec == b.sec and a.nanosec < b.nanosec)) diff --git a/pkgs/node_helpers/node_helpers/timing/README.md b/pkgs/node_helpers/node_helpers/timing/README.md new file mode 100644 index 0000000..e6506b1 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/timing/README.md @@ -0,0 +1,10 @@ +# node_helpers.timing + +The `node_helpers.timing` module provides utilities for managing time in ROS2, including context-based timing and caching. It includes various tools tailored for specific use cases: + +- **Timers**: Use `Timer` for profiling and `WarningTimer` to detect slow callbacks. +- **Mixins**: `SingleShotMixin` simplifies creating one-time timers, while `TimerWithWarningsMixin` logs warnings if timers fall behind. +- **Timeouts**: `Timeout` enforces time limits, and `TestingTimeout` is ideal for test cases. +- **TTL Caching**: The `ttl_cached` decorator enables caching method results for a set duration. + +These tools streamline time management and enhance system reliability. 
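
As a concrete illustration of the utilities listed in the README above, here is a minimal sketch. It uses only the `Timeout` and `ttl_cached` APIs introduced later in this diff; the `SensorProxy` class is hypothetical:

```python
import time

from node_helpers.timing import Timeout, ttl_cached


class SensorProxy:
    @ttl_cached(seconds=1.0)
    def reading(self) -> float:
        # Stand-in for an expensive query; recomputed at most once per second
        return time.time()


proxy = SensorProxy()
timeout = Timeout(seconds=0.5)
while timeout:  # Evaluates False once 0.5 seconds have elapsed
    proxy.reading()
```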
diff --git a/pkgs/node_helpers/node_helpers/timing/__init__.py b/pkgs/node_helpers/node_helpers/timing/__init__.py
new file mode 100644
index 0000000..28c932c
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/timing/__init__.py
@@ -0,0 +1,5 @@
+from .single_shots_mixin import SingleShotMixin
+from .timeout import TestingTimeout, Timeout, has_timed_out
+from .timer import Timer, WarningTimer
+from .timer_with_warnings_mixin import TimerWithWarningsMixin
+from .ttl_caches import ttl_cached
diff --git a/pkgs/node_helpers/node_helpers/timing/single_shots_mixin.py b/pkgs/node_helpers/node_helpers/timing/single_shots_mixin.py
new file mode 100644
index 0000000..a1c8955
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/timing/single_shots_mixin.py
@@ -0,0 +1,51 @@
+from collections.abc import Callable
+
+from rclpy.callback_groups import CallbackGroup
+from rclpy.clock import Clock
+from rclpy.node import Node
+from rclpy.timer import Timer
+
+
+class _TimerCallback:
+    def __init__(self, callback: Callable[[], None]):
+        self._timer: Timer = None  # Is assigned right after creation
+        self._callback = callback
+
+    def __call__(self) -> None:
+        # Cancel and destroy the timer, then remove references to it
+        self._timer.cancel()
+        self._timer.destroy()
+        self._timer = None
+
+        return self._callback()
+
+    def assign_timer(self, timer: Timer) -> None:
+        """Called once the timer has been created with this callback"""
+        self._timer = timer
+
+
+class SingleShotMixin:
+    """This mixin adds a method for creating a timer that will execute only once.
+
+    The benefits of using this mixin are:
+    1) It is tested and verified to not leave any references.
+    2) Less boilerplate
+    """
+
+    def create_single_shot_timer(
+        self: Node,
+        timer_period_sec: float,
+        callback: Callable[[], None],
+        callback_group: CallbackGroup = None,
+        clock: Clock = None,
+    ) -> Timer:
+        timer_callback = _TimerCallback(callback)
+        timer = self.create_timer(
+            timer_period_sec=timer_period_sec,
+            callback=timer_callback,
+            callback_group=callback_group,
+            clock=clock,
+        )
+        timer_callback.assign_timer(timer)
+
+        return timer
diff --git a/pkgs/node_helpers/node_helpers/timing/timeout.py b/pkgs/node_helpers/node_helpers/timing/timeout.py
new file mode 100644
index 0000000..b5b68d9
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/timing/timeout.py
@@ -0,0 +1,65 @@
+from timeit import default_timer as timer
+
+
+def has_timed_out(start: float, timeout: float) -> bool:
+    """Convenience method to return true if a period of time has elapsed
+
+    :param start: The start time of some arbitrary event
+    :param timeout: The time in seconds to check against
+    :return: True if the period has elapsed
+    """
+    return timer() - start > timeout
+
+
+class Timeout:
+    """This class is useful for checking if a time limit has been reached. It is
+    useful for tests, or for writing blocking calls that have the option to raise
+    a TimeoutError.
+
+    >>> timeout = Timeout(seconds=3)
+    >>> while timeout and other_condition():
+    >>>     pass
+
+    If `other_condition()` never returns false, the timeout object becomes false
+    after 3 seconds (or optionally raises a `TimeoutError`), ending the loop.
+    """
+
+    TIMEOUT_MESSAGE = "The Timeout of {seconds} seconds has been reached!" 
+
+    def __init__(
+        self,
+        seconds: float,
+        raise_error: bool = False,
+        timeout_message: str = TIMEOUT_MESSAGE,
+    ):
+        self.seconds = seconds
+        self._start = timer()
+        self._raise_error = raise_error
+        self._timeout_message = timeout_message
+
+    def __bool__(self) -> bool:
+        if has_timed_out(self._start, self.seconds):
+            if self._raise_error:
+                raise TimeoutError(self._timeout_message.format(seconds=self.seconds))
+            else:
+                return False
+        return True
+
+    def reset(self) -> None:
+        """Reset the timer"""
+        self._start = timer()
+
+    def reset_seconds(self, seconds: float) -> None:
+        """Reset the timer and set a new time limit.
+        This is useful for reusing the same object for changing time limits.
+
+        :param seconds: The new time limit"""
+        self.seconds = seconds
+        self.reset()
+
+
+class TestingTimeout(Timeout):
+    """This class is useful for raising a TimeoutError once a time limit has finished.
+    It will always raise a TimeoutError, and is mainly used for tests"""
+
+    def __init__(self, seconds: float, timeout_message: str = Timeout.TIMEOUT_MESSAGE):
+        super().__init__(seconds, raise_error=True, timeout_message=timeout_message)
diff --git a/pkgs/node_helpers/node_helpers/timing/timer.py b/pkgs/node_helpers/node_helpers/timing/timer.py
new file mode 100644
index 0000000..dbafee1
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/timing/timer.py
@@ -0,0 +1,185 @@
+import logging
+from collections import deque
+from collections.abc import Callable
+from functools import wraps
+from timeit import default_timer as timer
+from typing import Any, TypeVar, cast
+
+FuncT = TypeVar("FuncT", bound=Callable[..., Any])
+
+
+class Timer:
+    """Use this class when timing any section or function of code.
+
+    Example 1: As a context manager
+    >>> timer = Timer(300)
+    >>>
+    >>> with timer:
+    >>>     expensive_operation()
+    >>>     other_operations()
+    >>>
+    >>> print(timer)
+
+    Example 2: As a function decorator
+    >>> timer = Timer(10)
+    >>>
+    >>> @timer
+    >>> def cool_func():
+    >>>     expensive_operation()
+    >>>     other_operations()
+    >>>
+    >>> cool_func()
+    >>> print(timer.fps)
+
+    Example 3: Profiling multiple parts of a whole
+    >>> timer = Timer(15)
+    >>>
+    >>> with timer:
+    >>>     # do things
+    >>>     with timer.child("SpecificThing"):
+    >>>         # do a specific thing
+    >>>
+    >>> # This report will show the whole time and the times of the child as percentage
+    >>> print(timer)
+    """
+
+    _NO_SAMPLES_MSG = "No Samples!" 
+ _REPORT_INDENT = "\t" + + def __init__(self, samples: int | None = 1, name: str = "Timer", log: bool = False): + """ + :param samples: The number of samples to extract the rolling average from + :param name: The name of the timer, to put in reports + :param log: Whether to log the timer report after it is finished + """ + self.name = name + self._log = log + self._num_samples = samples + self._samples: deque[float] = deque(maxlen=samples) + self._sample_start: float | None = None + self._children: dict[str, Timer] = {} + + def __call__(self, method: FuncT) -> FuncT: + """This implements the decorator functionality of the timer""" + + @wraps(method) + def on_call(*args: Any, **kwargs: Any) -> Any: + with self: + return method(*args, **kwargs) + + return cast(FuncT, on_call) + + def __repr__(self) -> str: + return self.create_report() + + def __enter__(self) -> "Timer": + self.begin() + return self + + def __exit__(self, *args: object) -> None: + self.end() + if self._log: + logging.info(f"{self}") + + @property + def running(self) -> bool: + """Return True if the timer is currently running""" + return self._sample_start is not None + + def begin(self) -> None: + """Start collection of a single timer sample""" + if self.running: + raise RuntimeError("The 'begin' method cannot be called twice in a row!") + + self._sample_start = timer() + + def end(self) -> None: + """Finish collection of a single timer sample""" + if not self.running: + raise RuntimeError("'end' was called before 'begin'!") + + self._samples.append(self.current_elapsed) + self._sample_start = None + + def create_report(self, depth: int = 0, parent_elapsed: float | None = None) -> str: + """Create a report for this timer""" + if len(self._samples): + elapsed = self.elapsed + if parent_elapsed is None: + parent_elapsed = elapsed + report = ( + f"{self.name.title()}(" + f"{round((elapsed / parent_elapsed) * 100, 2)}%, " + f"elapsed={round(elapsed, 2)}, " + f"fps={round(self.fps, 2)}, " + f"samples={len(self._samples)})" + ) + + # Iterate over children appending their reports + for child in self._children.values(): + report += "\n" + child.create_report(depth + 1, parent_elapsed=elapsed) + else: + # Prevent a ZeroDivisionError + report = f"{self.name.title()}({self._NO_SAMPLES_MSG})" + report = self._REPORT_INDENT * depth + report + + return report + + def child(self, name: str) -> "Timer": + """Creates a child Timer of the given name, if none exists, and returns it.""" + if not self.running: + raise RuntimeError( + "You cannot create a child outside of the parent timers context!" 
+ ) + if name not in self._children: + self._children[name] = Timer(samples=self._num_samples, name=name) + return self._children[name] + + @property + def elapsed(self) -> float: + """Return the average elapsed time""" + if len(self._samples) == 0: + return 0.0 + return self.total_elapsed / len(self._samples) + + @property + def total_elapsed(self) -> float: + """Return the total time the timer spent active""" + return float(sum(self._samples)) + + @property + def fps(self) -> float: + """Return the average 'fps', or rather, 'operations per second'.""" + if len(self._samples) == 0: + return 0.0 + return 1 / self.elapsed + + def reset(self) -> None: + """Reset the timer, clearing all samples""" + self._samples.clear() + self._sample_start = None + + @property + def current_elapsed(self) -> float: + """Return the current elapsed time""" + if not self.running: + raise RuntimeError("You cannot collect a sample before starting the timer!") + return timer() - self._sample_start # type: ignore + + +class WarningTimer(Timer): + """A variant of Timer that logs if the timer Hz is below a certain threshold""" + + def __init__(self, name: str, target_hz: float, samples: int | None = 1): + super().__init__(samples=samples, name=name, log=False) + self.target_hz = target_hz + + def __exit__(self, *args: object) -> None: + super().__exit__(*args) + if self.fps < self.target_hz: + logging.warning( + f"{self.name} is below target Hz of {self.target_hz} with " + f"{self.fps:.2f} Hz" + ) diff --git a/pkgs/node_helpers/node_helpers/timing/timer_with_warnings_mixin.py b/pkgs/node_helpers/node_helpers/timing/timer_with_warnings_mixin.py new file mode 100644 index 0000000..7be239d --- /dev/null +++ b/pkgs/node_helpers/node_helpers/timing/timer_with_warnings_mixin.py @@ -0,0 +1,35 @@ +from collections.abc import Callable + +from rclpy.callback_groups import CallbackGroup +from rclpy.clock import Clock +from rclpy.node import Node +from rclpy.timer import Timer + +from node_helpers.timing import WarningTimer + + +class TimerWithWarningsMixin: + """This mixin adds a method for creating a timer that will log a warning when the + timer is falling behind due to the callback taking too long. 
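+
+    A hypothetical sketch, inside a node that uses this mixin:
+    >>> self.create_timer_with_warnings(
+    >>>     timer_period_sec=0.1, callback=self.control_loop, name="control loop"
+    >>> )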
+    """
+
+    def create_timer_with_warnings(
+        self: Node,
+        timer_period_sec: float,
+        callback: Callable[[], None],
+        name: str,
+        callback_group: CallbackGroup = None,
+        clock: Clock = None,
+    ) -> Timer:
+        warning_timer = WarningTimer(name, 1 / timer_period_sec)
+
+        def wrapped() -> None:
+            with warning_timer:
+                callback()
+
+        return self.create_timer(
+            timer_period_sec=timer_period_sec,
+            callback=wrapped,
+            callback_group=callback_group,
+            clock=clock,
+        )
diff --git a/pkgs/node_helpers/node_helpers/timing/ttl_caches.py b/pkgs/node_helpers/node_helpers/timing/ttl_caches.py
new file mode 100644
index 0000000..1d7f7d6
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers/timing/ttl_caches.py
@@ -0,0 +1,91 @@
+import time
+from collections.abc import Callable
+from typing import Any, Generic, TypeVar, cast
+
+# Type variable for the method return type
+ReturnType = TypeVar("ReturnType")
+SomeCallable = Callable[..., ReturnType]
+
+
+class TTLWrappedFuncType(Generic[ReturnType]):
+    """Type hint helper for the decorator"""
+
+    cache: "TTLCache"
+    __call__: Callable[..., ReturnType]
+
+
+class TTLCache:
+    def __init__(self, ttl_seconds: float):
+        self.ttl_seconds = ttl_seconds
+        self._cached_value = None
+        self._last_value_timestamp = 0.0
+        self._is_set = False  # Indicates whether the cache has been set
+        self._get_time = time.time
+
+    def get(self) -> tuple[bool, Any]:
+        """Returns (hit, value), where 'hit' is False once the TTL has expired"""
+        if (
+            self._is_set
+            and self._get_time() - self._last_value_timestamp < self.ttl_seconds
+        ):
+            return True, self._cached_value
+        return False, None
+
+    def set(self, value: Any) -> None:
+        self._cached_value = value
+        self._last_value_timestamp = self._get_time()
+        self._is_set = True
+
+
+def ttl_cached(
+    seconds: float,
+) -> Callable[[SomeCallable[ReturnType]], TTLWrappedFuncType[ReturnType]]:
+    """
+    A function that adds time-to-live (TTL) caching behavior to another function.
+    The function's value is cached for a specified number of seconds. After the TTL
+    expires, the function is recomputed upon the next access.
+
+    It should be noted that this implementation does not (yet) use weakrefs, so the
+    cache will hold on to the value until the next recomputation.
+
+    :param seconds: The number of seconds the return value should be cached for.
+    :return: A decorator that transforms a class method into a TTL cached function.
+
+    Usage:
+    >>> class MyClass:
+    ...     @ttl_cached(seconds=10)
+    ...     def my_function(self) -> int:
+    ...         # Simulate a computation or database access
+    ...         return int(time.time())
+    ... 
+ >>> obj = MyClass() + >>> value1 = obj.my_function() # This call computes the value + >>> value2 = obj.my_function() # This call returns the cached value + >>> time.sleep(10) + >>> value3 = obj.my_function() # This call computes a new value + + The ttl can also be edited, via + >>> obj.my_function.cache.ttl_seconds = 5 + """ + + def decorator( + func: SomeCallable[ReturnType], + ) -> TTLWrappedFuncType[ReturnType]: + cache = TTLCache(seconds) + + def wrapper(*args: Any, **kwargs: Any) -> ReturnType: + nonlocal cache + cached, cached_result = cache.get() + if cached: + return cast(ReturnType, cached_result) + else: + result = func(*args, **kwargs) + cache.set(result) + return result + + wrapper = cast(TTLWrappedFuncType[ReturnType], wrapper) + + # Attach the TTLCache instance to the wrapper function for later adjustments + wrapper.cache = cache + return wrapper + + return decorator diff --git a/pkgs/node_helpers/node_helpers/topics/README.md b/pkgs/node_helpers/node_helpers/topics/README.md new file mode 100644 index 0000000..c707230 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/topics/README.md @@ -0,0 +1,7 @@ +# node_helpers.topics + +This module provides tools for working with ROS2 topics, including utilities that augment publishing and subscribing. + +- **LatchingPublisher**: A publisher with behavior similar to ROS1's latching QoS. It ensures the latest message is consistently available to new subscribers by periodically republishing the last message. + +These tools help ensure robust and reliable topic communication in ROS2 systems. diff --git a/pkgs/node_helpers/node_helpers/topics/__init__.py b/pkgs/node_helpers/node_helpers/topics/__init__.py new file mode 100644 index 0000000..1ee283a --- /dev/null +++ b/pkgs/node_helpers/node_helpers/topics/__init__.py @@ -0,0 +1 @@ +from .latching_publisher import LatchingPublisher diff --git a/pkgs/node_helpers/node_helpers/topics/latching_publisher.py b/pkgs/node_helpers/node_helpers/topics/latching_publisher.py new file mode 100644 index 0000000..ec83e52 --- /dev/null +++ b/pkgs/node_helpers/node_helpers/topics/latching_publisher.py @@ -0,0 +1,49 @@ +from typing import Generic, TypeVar + +from rclpy.callback_groups import CallbackGroup +from rclpy.node import Node + +from node_helpers.qos import qos_latching + +T = TypeVar("T") + + +class LatchingPublisher(Generic[T]): + """Publishes a message using QoS behavior similar to ROS1's latching QoS. + The latest provided message is also routinely republished to ensure that + improperly configured subscribers (like roslibjs) will still receive + messages. + """ + + def __init__( + self, + node: Node, + msg_type: type[T], + topic: str, + republish_delay: float = 1.0, + *, + callback_group: CallbackGroup | None = None, + ): + self._publisher = node.create_publisher( + msg_type, topic, qos_profile=qos_latching, callback_group=callback_group + ) + self._last_msg: T | None = None + + # Routinely republish the message. This shouldn't be necessary since + # these topics are durable and transient local, but roslibjs will + # improperly configure its subscribers if it subscribes before this + # node has published. Sending routinely ensures that roslibjs will see + # messages on this topic even in that case. 
+ node.create_timer(republish_delay, callback=self._republish) + + def __call__(self, msg: T) -> None: + self._last_msg = msg + self._publisher.publish(msg) + + def clear_msg_state(self) -> None: + """Used for deferring control of the latching publisher to another node""" + self._last_msg = None + + def _republish(self) -> None: + if self._last_msg is not None: + self._publisher.publish(self._last_msg) diff --git a/pkgs/node_helpers/node_helpers_test/integration/nodes/test_node_helpers_node.py b/pkgs/node_helpers/node_helpers_test/__init__.py similarity index 100% rename from pkgs/node_helpers/node_helpers_test/integration/nodes/test_node_helpers_node.py rename to pkgs/node_helpers/node_helpers_test/__init__.py diff --git a/pkgs/node_helpers/node_helpers_test/unit/nodes/test_node_helpers_node.py b/pkgs/node_helpers/node_helpers_test/integration/actions/__init__.py similarity index 100% rename from pkgs/node_helpers/node_helpers_test/unit/nodes/test_node_helpers_node.py rename to pkgs/node_helpers/node_helpers_test/integration/actions/__init__.py diff --git a/pkgs/node_helpers/node_helpers_test/integration/actions/server/__init__.py b/pkgs/node_helpers/node_helpers_test/integration/actions/server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/node_helpers/node_helpers_test/integration/actions/server/test_base_handler_and_worker.py b/pkgs/node_helpers/node_helpers_test/integration/actions/server/test_base_handler_and_worker.py new file mode 100644 index 0000000..59c2ccb --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/actions/server/test_base_handler_and_worker.py @@ -0,0 +1,169 @@ +from collections.abc import Generator +from copy import deepcopy +from queue import Queue +from threading import Event +from typing import Any +from uuid import UUID + +import pytest +from node_helpers.actions.server import ActionCallMetric, ActionHandler +from node_helpers.futures import wait_for_future +from node_helpers.nodes import HelpfulNode +from node_helpers.robust_rpc import RobustRPCException +from node_helpers.testing import NodeForTesting, set_up_node +from node_helpers_msgs.action import RobustActionExample +from rclpy.action.client import ClientGoalHandle +from rclpy.action.server import ServerGoalHandle +from rclpy.callback_groups import MutuallyExclusiveCallbackGroup + +from .utils import ActionWorkerForTesting, CoolError + + +class ExampleActionHandler( + ActionHandler[ + RobustActionExample.Goal, + RobustActionExample.Feedback, + RobustActionExample.Result, + ] +): + def __init__(self, node: HelpfulNode): + super().__init__( + node=node, + action_name="example_action", + action_type=RobustActionExample, + callback_group=MutuallyExclusiveCallbackGroup(), + ) + self.raise_exception = False + self.on_exception_called = Event() + self.done_event: Event | None = None + + def create_worker(self, goal_handle: ServerGoalHandle) -> ActionWorkerForTesting: + return ActionWorkerForTesting( + goal_handle, + done_event=self.done_event, + raise_exception=self.raise_exception, + on_exception_called=self.on_exception_called, + ) + + +class ExampleActionServer(HelpfulNode): + def __init__(self, **kwargs: Any): + super().__init__("example_action_server", **kwargs) + self.handler = ExampleActionHandler(self) + + +@pytest.fixture() +def example_action_server() -> Generator[ExampleActionServer, None, None]: + yield from set_up_node( + ExampleActionServer, + "action_handler", + "example_action_server", + multi_threaded=True, + ) + + +class 
ExampleActionClient(NodeForTesting):
+    def __init__(self, **kwargs: Any):
+        super().__init__("example_action_client", **kwargs)
+        self.client = self.create_robust_action_client(
+            RobustActionExample, "example_action"
+        )
+
+
+@pytest.fixture()
+def example_action_client(
+    example_action_server: ExampleActionServer,
+) -> Generator[ExampleActionClient, None, None]:
+    yield from set_up_node(
+        ExampleActionClient,
+        "action_handler",
+        "example_action_client",
+        multi_threaded=True,
+    )
+
+
+def test_basic_operation(
+    example_action_server: ExampleActionServer,
+    example_action_client: ExampleActionClient,
+) -> None:
+    """Tests that the action server responds properly to a single goal"""
+    result = example_action_client.client.send_goal(RobustActionExample.Goal())
+    assert result.result.data == "success"
+
+
+def test_on_exception(
+    example_action_client: ExampleActionClient,
+    example_action_server: ExampleActionServer,
+) -> None:
+    """Tests that the action server reports an error to the client when an exception is
+    raised
+    """
+    example_action_server.handler.raise_exception = True
+
+    with pytest.raises(RobustRPCException.like(CoolError)):
+        example_action_client.client.send_goal(RobustActionExample.Goal())
+
+    assert example_action_server.handler.on_exception_called.is_set()
+
+
+@pytest.mark.parametrize("expected_result", ["success", "error", "canceled"])
+def test_report_metric(
+    example_action_client: ExampleActionClient,
+    example_action_server: ExampleActionServer,
+    expected_result: str,
+) -> None:
+    """Tests that the action server sends metrics reporting information on the action
+    start time and final result.
+    """
+    publish_queue: Queue[ActionCallMetric] = Queue()
+
+    def on_metric(m: ActionCallMetric) -> None:
+        # The action server modifies metrics in-place, so we need to copy them to keep
+        # their original state
+        publish_queue.put(deepcopy(m))
+
+    example_action_server.handler._metrics_callback = on_metric
+
+    # Set up the action to either error out or allow cancellation, when appropriate
+    if expected_result == "error":
+        example_action_server.handler.raise_exception = True
+    elif expected_result == "canceled":
+        # Make the action wait on an event so we have the chance to cancel it
+        example_action_server.handler.done_event = Event()
+
+    goal_handle = wait_for_future(
+        example_action_client.client.send_goal_async(RobustActionExample.Goal()),
+        ClientGoalHandle,
+        timeout=10,
+    )
+
+    if expected_result == "canceled":
+        goal_handle.cancel_goal()
+
+    if expected_result == "error":
+        with pytest.raises(RobustRPCException.like(CoolError)):
+            goal_handle.get_result()
+    else:
+        goal_handle.get_result()
+
+    assert publish_queue.qsize() == 2
+
+    # Validate the first metric marks the action as in progress
+    start_metric = publish_queue.get_nowait()
+    assert start_metric.result == "in_progress"
+
+    # Validate the final metric reports the elapsed time and the final result
+    final_metric = publish_queue.get_nowait()
+    elapsed = final_metric.elapsed
+    assert isinstance(elapsed, float)
+    assert elapsed > 0
+    assert final_metric.result == expected_result
+
+    # Validate fields that are common across both metrics
+    expected_uuid = UUID(bytes=bytes(goal_handle.goal_id.uuid))
+    for metric in (start_metric, final_metric):
+        assert metric.action_name == "example_action"
+        assert metric.node_namespace == "/action_handler"
+
+        # Validate both metrics share the same goal_id
+        assert metric.goal_id == expected_uuid
diff --git a/pkgs/node_helpers/node_helpers_test/integration/actions/server/test_fail_fast_action_handler.py
b/pkgs/node_helpers/node_helpers_test/integration/actions/server/test_fail_fast_action_handler.py new file mode 100644 index 0000000..8bcb124 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/actions/server/test_fail_fast_action_handler.py @@ -0,0 +1,113 @@ +from collections.abc import Generator +from threading import Event +from typing import Any + +import pytest +from node_helpers import futures +from node_helpers.actions.server import ( + FailFastActionHandler, + SynchronousActionCalledInParallel, +) +from node_helpers.nodes import HelpfulNode +from node_helpers.robust_rpc import RobustRPCException +from node_helpers.testing import set_up_node +from node_helpers_msgs.action import RobustActionExample +from node_helpers_msgs.action._robust_action_example import ( + RobustActionExample_GetResult_Response, +) +from rclpy.action.server import ServerGoalHandle + +from .utils import ActionWorkerForTesting + + +class ExampleActionHandler( + FailFastActionHandler[ + RobustActionExample.Goal, + RobustActionExample.Feedback, + RobustActionExample.Result, + ] +): + def __init__(self, node: HelpfulNode): + super().__init__( + node=node, + action_name="example_action", + action_type=RobustActionExample, + ) + self.done_event = Event() + + def create_worker(self, goal_handle: ServerGoalHandle) -> ActionWorkerForTesting: + return ActionWorkerForTesting( + goal_handle=goal_handle, done_event=self.done_event + ) + + +class ExampleActionServer(HelpfulNode): + def __init__(self, **kwargs: Any): + super().__init__("example_action_server", **kwargs) + self.handler = ExampleActionHandler(self) + + +class ExampleActionClient(HelpfulNode): + def __init__(self, **kwargs: Any): + super().__init__("example_action_client", **kwargs) + self.client = self.create_robust_action_client( + RobustActionExample, "example_action" + ) + + +@pytest.fixture() +def example_action_server() -> Generator[ExampleActionServer, None, None]: + yield from set_up_node( + ExampleActionServer, + "action_handler", + "example_action_server", + multi_threaded=True, + ) + + +@pytest.fixture() +def example_action_client( + example_action_server: ExampleActionServer, +) -> Generator[ExampleActionClient, None, None]: + yield from set_up_node( + ExampleActionClient, + "action_handler", + "example_action_client", + ) + + +def test_basic_usage( + example_action_client: ExampleActionClient, + example_action_server: ExampleActionServer, +) -> None: + """Test that while an action is being run, any other attempts to run it result in + the expected exception.""" + + first_request, _ = futures.wait_for_send_goal( + example_action_client.client, RobustActionExample.Goal() + ) + second_request, _ = futures.wait_for_send_goal( + example_action_client.client, RobustActionExample.Goal() + ) + + # Ensure the second action runs into an exception + with pytest.raises(RobustRPCException.like(SynchronousActionCalledInParallel)): + futures.wait_for_future(second_request, object, timeout=10) + + # Ensure the first request is still active + assert not first_request.done() + assert second_request.done() + + # Allow the first request to finish + example_action_server.handler.done_event.set() + futures.wait_for_future( + first_request, RobustActionExample_GetResult_Response, timeout=10 + ) + + # Validate a third request can still go through without exceptions + third_request, _ = futures.wait_for_send_goal( + example_action_client.client, RobustActionExample.Goal() + ) + futures.wait_for_future( + third_request, RobustActionExample_GetResult_Response, 
timeout=10 + ) diff --git a/pkgs/node_helpers/node_helpers_test/integration/actions/server/test_queued_action_handler.py b/pkgs/node_helpers/node_helpers_test/integration/actions/server/test_queued_action_handler.py new file mode 100644 index 0000000..77c2168 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/actions/server/test_queued_action_handler.py @@ -0,0 +1,108 @@ +import queue +from collections.abc import Generator +from threading import Event +from typing import Any + +import pytest +from node_helpers.actions.server import QueuedActionHandler +from node_helpers.futures import wait_for_future +from node_helpers.nodes import HelpfulNode +from node_helpers.testing import set_up_node +from node_helpers_msgs.action import RobustActionExample +from node_helpers_msgs.action._robust_action_example import ( + RobustActionExample_GetResult_Response, +) +from rclpy.action.client import ClientGoalHandle +from rclpy.action.server import ServerGoalHandle +from rclpy.task import Future + +from .utils import ActionWorkerForTesting + + +class ExampleActionHandler( + QueuedActionHandler[ + RobustActionExample.Goal, + RobustActionExample.Feedback, + RobustActionExample.Result, + ] +): + def __init__(self, node: HelpfulNode, done_events: "queue.Queue[Event]"): + super().__init__( + node=node, + action_name="example_action", + action_type=RobustActionExample, + ) + self.done_events = done_events + + def create_worker(self, goal_handle: ServerGoalHandle) -> ActionWorkerForTesting: + done_event = Event() + self.done_events.put(done_event) + return ActionWorkerForTesting(goal_handle=goal_handle, done_event=done_event) + + +class ExampleActionServer(HelpfulNode): + def __init__(self, **kwargs: Any): + super().__init__("example_action_server", **kwargs) + self.done_events: "queue.Queue[Event]" = queue.Queue() + self.handler = ExampleActionHandler(self, self.done_events) + + +@pytest.fixture() +def example_action_server() -> Generator[ExampleActionServer, None, None]: + yield from set_up_node( + ExampleActionServer, + "action_handler", + "example_action_server", + multi_threaded=True, + ) + + +class ExampleActionClient(HelpfulNode): + def __init__(self, **kwargs: Any): + super().__init__("example_action_client", **kwargs) + self.client = self.create_robust_action_client( + RobustActionExample, "example_action" + ) + + +@pytest.fixture() +def example_action_client( + example_action_server: ExampleActionServer, +) -> Generator[ExampleActionClient, None, None]: + yield from set_up_node( + ExampleActionClient, + "action_handler", + "example_action_client", + ) + + +def test_queued_work( + example_action_client: ExampleActionClient, + example_action_server: ExampleActionServer, +) -> None: + """Tests that the action server responds to multiple requests one-by-one""" + first_request: Future = wait_for_future( + example_action_client.client.send_goal_async(RobustActionExample.Goal()), + ClientGoalHandle, + timeout=10, + ).get_result_async() + second_request: Future = wait_for_future( + example_action_client.client.send_goal_async(RobustActionExample.Goal()), + ClientGoalHandle, + timeout=10, + ).get_result_async() + + # Ensure that both actions are waiting + assert not first_request.done() + assert not second_request.done() + + # Let the first action go through + example_action_server.done_events.get().set() + wait_for_future(first_request, RobustActionExample_GetResult_Response, timeout=10) + assert first_request.done() + assert not second_request.done() + + # Let the second action go through + 
example_action_server.done_events.get().set()
+    wait_for_future(second_request, RobustActionExample_GetResult_Response, timeout=10)
+    assert second_request.done()
diff --git a/pkgs/node_helpers/node_helpers_test/integration/actions/server/utils.py b/pkgs/node_helpers/node_helpers_test/integration/actions/server/utils.py
new file mode 100644
index 0000000..21f192a
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers_test/integration/actions/server/utils.py
@@ -0,0 +1,67 @@
+from collections.abc import Generator
+from threading import Event
+
+from node_helpers.actions.server import ActionWorker
+from node_helpers.timing import TestingTimeout as Timeout
+from node_helpers_msgs.action import RobustActionExample
+from rclpy.action.server import ServerGoalHandle
+
+
+class CoolError(Exception):
+    pass
+
+
+class ActionWorkerForTesting(
+    ActionWorker[
+        RobustActionExample.Goal,
+        RobustActionExample.Feedback,
+        RobustActionExample.Result,
+    ]
+):
+    """An ActionWorker that can simulate a variety of behaviors based on how it's
+    configured
+    """
+
+    def __init__(
+        self,
+        goal_handle: ServerGoalHandle,
+        done_event: Event | None = None,
+        raise_exception: bool = False,
+        on_exception_called: Event | None = None,
+    ):
+        """
+        :param goal_handle: A handle to the current goal
+        :param done_event: If provided, the action will wait on this event before
+            finishing, yielding routinely
+        :param raise_exception: If True, the action will raise an exception instead
+            of completing
+        :param on_exception_called: Will be set if an exception occurs. This must be
+            provided if raise_exception is True.
+        """
+        super().__init__(goal_handle)
+        self.done_event = done_event
+        self.raise_exception = raise_exception
+        self.on_exception_called = on_exception_called
+        if self.on_exception_called is not None:
+            self.on_exception_called.clear()
+
+    def run(self) -> Generator[RobustActionExample.Feedback | None, None, None]:
+        yield None
+        if self.raise_exception:
+            raise CoolError("I might be an error, but dang I'm cool")
+
+        # Wait for a done event if it's set
+        if self.done_event is not None:
+            timeout = Timeout(10)
+            while not self.done_event.wait(timeout=0.1) and timeout:
+                yield RobustActionExample.Feedback()
+
+        self.result = RobustActionExample.Result(data="success")
+
+    def on_cancel(self) -> RobustActionExample.Result:
+        return RobustActionExample.Result(data="canceled")
+
+    def on_exception(self, ex: Exception) -> None:
+        if self.on_exception_called is None:
+            raise RuntimeError("An exception occurred but the event object is not set")
+        self.on_exception_called.set()
diff --git a/pkgs/node_helpers/node_helpers_test/integration/actions/test_action_sequences.py b/pkgs/node_helpers/node_helpers_test/integration/actions/test_action_sequences.py
new file mode 100644
index 0000000..8a6f6d8
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers_test/integration/actions/test_action_sequences.py
@@ -0,0 +1,620 @@
+from threading import Event
+
+import pytest
+from action_msgs.msg import GoalStatus
+from node_helpers import futures, testing
+from node_helpers.actions import (
+    ActionElement,
+    ActionGroup,
+    ActionSequence,
+    AlreadyRunningActionsError,
+    NoRunningActionsError,
+    ParallelActionSequences,
+)
+from node_helpers.robust_rpc import RobustRPCException
+from node_helpers.timing import TestingTimeout
+from node_helpers_msgs.action import RobustActionExample
+
+from .server.test_base_handler_and_worker import (
+    ExampleActionClient,
+    ExampleActionServer,
+    example_action_client,  # noqa: F401
+    example_action_server,
# noqa: F401 +) +from .server.utils import CoolError + + +def test_action_sequence_basic_usage_without_feedback( + example_action_server: ExampleActionServer, # noqa: F811 + example_action_client: ExampleActionClient, # noqa: F811 +) -> None: + """Test the high-level ActionSequence, ActionGroup functionality, with actions that + don't use 'continue' feedback triggers""" + + group_a = ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ) + group_b = ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ) + sequence = ActionSequence(group_a, group_b) + + results = sequence.execute() + + # Verify actions were run and results match up with the number of actions + assert len(results) == 2 # two groups + assert len(results[0]) == 2 # two actions, group 1 + assert len(results[1]) == 1 # one action, group 2 + assert len(group_a._running_actions) == 0 + assert len(group_b._running_actions) == 0 + + +def test_action_group_send_goals( + example_action_server: ExampleActionServer, # noqa: F811 + example_action_client: ExampleActionClient, # noqa: F811 +) -> None: + """Basic test that send_goals sends goals and waits for results to finish""" + group = ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ) + results = group.send_goals() + assert len(results) == 3 + + # The running futures and action triggers should be reset + assert len(group._running_actions) == 0 + + +def test_cancel_ongoing_goals( + example_action_server: ExampleActionServer, # noqa: F811 + example_action_client: ExampleActionClient, # noqa: F811 +) -> None: + """Try cancelling goals that are still running""" + + group = ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ) + + # It should be okay to cancel when there are no running actions. This is helpful + # because when ActionWorkers are handling their on_cancel's, they can just cancel + # any action groups they have (regardless of whether they were doing anything). 
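+    # In other words, cancel_goals() is expected to be idempotent and safe to
+    # call on a group with nothing in flight.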
+    assert len(group._running_actions) == 0
+    group.cancel_goals()  # No exception
+
+    # Now run the action, ensuring it never finishes by adding a done_event
+    example_action_server.handler.done_event = Event()
+    group.send_goals_async()
+    ongoing_actions = group._running_actions
+    assert len(ongoing_actions) == 3
+
+    # Now call cancel_goals and validate all actions ended in cancellation
+    group.cancel_goals()
+
+    # The group should clear internal running actions after cancellation
+    assert len(group._running_actions) == 0
+
+    # All result futures should be completed
+    assert all(a.result_future.done() for a in ongoing_actions)
+
+    # All clients should report cancelled
+    assert all(
+        a.goal_handle.status == GoalStatus.STATUS_CANCELED for a in ongoing_actions
+    )
+
+
+def test_cancel_goals_with_finished_futures(
+    example_action_server: ExampleActionServer,  # noqa: F811
+    example_action_client: ExampleActionClient,  # noqa: F811
+) -> None:
+    """Test you can still safely call 'cancel_goals' if the action finished already"""
+    group = ActionGroup(
+        ActionElement(example_action_client.client, RobustActionExample.Goal()),
+        ActionElement(example_action_client.client, RobustActionExample.Goal()),
+    )
+    group.send_goals_async()
+
+    # Wait until the underlying actions have finished
+    futures.wait_for_futures([a.result_future for a in group._running_actions], object)
+    assert len(group._running_actions) == 2
+    assert all(a.result_future.done() for a in group._running_actions)
+
+    # This shouldn't fail
+    group.cancel_goals()
+    assert len(group._running_actions) == 0
+
+
+def test_action_group_accepts_iterables_and_elements(
+    example_action_server: ExampleActionServer,  # noqa: F811
+    example_action_client: ExampleActionClient,  # noqa: F811
+) -> None:
+    """A simple test that you can put generators in an ActionGroup.
+
+    For example, the following should be acceptable:
+        ActionGroup(ActionElement(client=client, ...) for client in clients)
+
+    You should also be able to:
+        ActionGroup((e for e in elements), element_c, element_d)
+    """
+    element_a = ActionElement(example_action_client.client, RobustActionExample.Goal())
+    element_b = ActionElement(example_action_client.client, RobustActionExample.Goal())
+    element_c = ActionElement(example_action_client.client, RobustActionExample.Goal())
+
+    as_list = [element_a, element_b, element_c]
+    as_generator = (e for e in as_list)
+    as_mix_of_generator_and_items = ((e for e in (element_a, element_b)), element_c)
+
+    assert ActionGroup(as_list)._action_elements == as_list
+    assert ActionGroup(*as_list)._action_elements == as_list
+    assert ActionGroup(as_generator)._action_elements == as_list
+    assert ActionGroup(*as_mix_of_generator_and_items)._action_elements == as_list
+
+
+def test_action_group_basic_usage_without_feedback(
+    example_action_server: ExampleActionServer,  # noqa: F811
+    example_action_client: ExampleActionClient,  # noqa: F811
+) -> None:
+    """Test the high-level ActionGroup functionality, with actions that don't use
+    'continue' feedback triggers.
+
+    It also validates that the group's '_running_actions' bookkeeping is cleared
+    once results are collected.
+ """ + + example_action_server.handler.done_event = Event() + + group = ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ) + assert len(group._running_actions) == 0 + + group.send_goals_async() + assert len(group._running_actions) == 2 + assert all(not r.result_future.done() for r in group._running_actions) + assert all(not r.next_action_trigger.is_set() for r in group._running_actions) + + # Allow the action to finish + example_action_server.handler.done_event.set() + group.wait_for_feedback_triggers() + assert len(group._running_actions) == 2 + assert all(r.next_action_trigger.is_set() for r in group._running_actions) + assert all(r.result_future.done() for r in group._running_actions) + + results = group.wait_for_results() + assert len(results) == 2 + assert len(group._running_actions) == 0 + + +def test_action_group_basic_usage_with_feedback( + example_action_server: ExampleActionServer, # noqa: F811 + example_action_client: ExampleActionClient, # noqa: F811 +) -> None: + """Test the high-level ActionGroup functionality, with an action group that has the + feedback_callback specified + """ + + example_action_server.handler.done_event = Event() + + # Add a feedback callback that immediately sets the feedback_trigger + def feedback_callback(feedback: RobustActionExample.Feedback) -> bool: + assert isinstance(feedback, RobustActionExample.Feedback) + return True + + group = ActionGroup( + ActionElement( + example_action_client.client, RobustActionExample.Goal(), feedback_callback + ), + ActionElement( + example_action_client.client, RobustActionExample.Goal(), feedback_callback + ), + ) + assert len(group._running_actions) == 0 + + group.send_goals_async() + group.wait_for_feedback_triggers() + + # None of the actions should have their results set yet, but all feedback triggers + # should be set by now + assert len(group._running_actions) == 2 + assert all(r.next_action_trigger.is_set() for r in group._running_actions) + assert all(not r.result_future.done() for r in group._running_actions) + + # All of the events must be unique objects + assert len({r.next_action_trigger for r in group._running_actions}) == 2 + + # Allow the action to finish + example_action_server.handler.done_event.set() + + results = group.wait_for_results() + assert all(isinstance(r, RobustActionExample.Result) for r in results), ( + "The wait_for_results method should return the underlying result objects, not " + "the *_GetResult_Response message!" 
+ ) + assert len(group._running_actions) == 0 + assert len(results) == 2 + + +def test_action_group_exception_cases( + example_action_server: ExampleActionServer, # noqa: F811 + example_action_client: ExampleActionClient, # noqa: F811 +) -> None: + """Test all the usages of ActionGroup that should end in exceptions""" + + example_action_server.handler.raise_exception = True + + with pytest.raises(ValueError): + # You shouldn't be able to create an ActionGroup without ActionElements + ActionGroup() + + group = ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ) + + # Waiting for triggers or results should fail if the actions haven't been started + with pytest.raises(NoRunningActionsError): + group.wait_for_feedback_triggers() + + with pytest.raises(NoRunningActionsError): + group.wait_for_results() + + with pytest.raises(NoRunningActionsError): + next(group.yield_for_results(yield_interval=999)) + + group.send_goals_async() + with pytest.raises(AlreadyRunningActionsError): + # Running send_goals_async twice should fail + group.send_goals_async() + + # Validate exceptions are raised in wait_for_feedback_triggers + with pytest.raises(RobustRPCException.like(CoolError)): + group.wait_for_feedback_triggers() + + # Validate exceptions are raised in yield_for_feedback_triggers + with pytest.raises(RobustRPCException.like(CoolError)): + testing.exhaust_generator(group.yield_for_feedback_triggers()) + + # Validate exceptions are raised in yield_for_results + with pytest.raises(RobustRPCException.like(CoolError)): + testing.exhaust_generator(group.yield_for_results(yield_interval=0)) + + # Validate exceptions are raised in wait_for_results + with pytest.raises(RobustRPCException.like(CoolError)): + group.wait_for_results() + + +def test_action_group_yield_for_results( + example_action_server: ExampleActionServer, # noqa: F811 + example_action_client: ExampleActionClient, # noqa: F811 +) -> None: + # Set a 'done_event' to ensure the actions don't finish until my say so + example_action_server.handler.done_event = Event() + group = ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ) + group.send_goals_async() + + # Normally this would be called like "result = yield from group.yield_for_results()" + # but since we aren't running in an actionworker 'run' loop, we call next manually + yield_generator = group.yield_for_results(yield_interval=0) + + # Check that it yields many times over before a result is given + for _ in range(100): + assert next(yield_generator) is None + + # Validate none of the futures are done yet + assert all(not a.result_future.done() for a in group._running_actions) + + # Now allow all the actions to continue, and continue yielding meanwhile + example_action_server.handler.done_event.set() + with pytest.raises(StopIteration) as error: + while True: + assert next(yield_generator) is None + + assert error.value.value is not None + assert len(group._running_actions) == 0, "All actions should have been cleared!" 
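+    # Note: a generator's return value travels on StopIteration.value, which is
+    # why the results are read from error.value.value above.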
+ + +def test_action_sequence_exception_cases( + example_action_server: ExampleActionServer, # noqa: F811 + example_action_client: ExampleActionClient, # noqa: F811 +) -> None: + example_action_server.handler.raise_exception = True + sequence = ActionSequence( + ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()) + for _ in range(3) + ), + ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ), + ) + + # Test that execute raises an exception + with pytest.raises(RobustRPCException.like(CoolError)): + sequence.execute() + + # Test that yield_for_execution raises an exception + with pytest.raises(RobustRPCException.like(CoolError)): + testing.exhaust_generator(sequence.yield_for_execution()) + + # The first group should have 3 running actions, the second group SHOULD NOT HAVE + # HAD ANY ACTIONS RUN! This basically validates that the exception was caught, + # raised, and the sequence was able to stop the second group from running. + assert len(sequence._action_groups[0]._running_actions) == 3 + assert len(sequence._action_groups[1]._running_actions) == 0 + + # Test that wait_for_results raises an exception + with pytest.raises(RobustRPCException.like(CoolError)): + sequence.wait_for_results() + + +def test_parallel_action_sequences_basic_execution( + example_action_server: ExampleActionServer, # noqa: F811 + example_action_client: ExampleActionClient, # noqa: F811 +) -> None: + """Test basic parallel execution of multiple action sequences.""" + + # Create two action sequences + sequence1 = ActionSequence( + ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ), + ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ), + ) + + sequence2 = ActionSequence( + ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ), + ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ), + ) + + # Create ParallelActionSequences with the two sequences + parallel_sequences = ParallelActionSequences(sequence1, sequence2) + + # Execute the sequences in parallel + _, results = testing.exhaust_generator(parallel_sequences.yield_for_execution()) + + # Verify that results are returned for both sequences + assert len(results) == 2 # Two sequences + assert len(results[0]) == 2 # Two action groups in sequence1 + assert len(results[1]) == 2 # Two action groups in sequence2 + + # Verify that each action group has one result + assert all(len(group_results) == 1 for group_results in results[0]) + assert all(len(group_results) == 1 for group_results in results[1]) + + +def test_parallel_action_sequences_with_feedback( + example_action_server: ExampleActionServer, # noqa: F811 + example_action_client: ExampleActionClient, # noqa: F811 +) -> None: + """Test parallel execution of action sequences with feedback callbacks.""" + + # Event to control when feedback is received + feedback_event = Event() + + # Feedback callback that waits for the event to be set + def feedback_callback(feedback: RobustActionExample.Feedback) -> bool: + return feedback_event.is_set() + + # Set up sequences with feedback callbacks + sequence1 = ActionSequence( + ActionGroup( + ActionElement( + example_action_client.client, + RobustActionExample.Goal(), + feedback_callback=feedback_callback, + ), + ), + ActionGroup( + ActionElement(example_action_client.client, 
RobustActionExample.Goal()),
+        ),
+        ActionGroup(
+            ActionElement(example_action_client.client, RobustActionExample.Goal()),
+        ),
+    )
+
+    sequence2 = ActionSequence(
+        ActionGroup(
+            ActionElement(
+                example_action_client.client,
+                RobustActionExample.Goal(),
+                feedback_callback=feedback_callback,
+            ),
+        ),
+    )
+
+    # Start the parallel sequences
+    parallel_sequences = ParallelActionSequences(sequence1, sequence2)
+
+    # Since we're simulating, we'll step through the execution manually
+    sequence_generators = [
+        seq.yield_for_execution() for seq in parallel_sequences._action_sequences
+    ]
+
+    # Initially, feedback callbacks will block waiting for feedback_event
+    # Let's step through a few iterations
+    for _ in range(100):
+        for gen in sequence_generators:
+            assert next(gen) is None
+
+    # At this point, both sequences should be waiting at the feedback trigger
+    # Set the feedback event to allow them to proceed
+    feedback_event.set()
+
+    # Continue stepping through execution
+    result_count = 0
+    timeout = TestingTimeout(10)
+    for sequence_idx, gen in enumerate(sequence_generators):
+        while timeout:
+            try:
+                next(gen)
+
+            except StopIteration as finished:
+                result_count += 1
+                assert finished.value is not None  # noqa: PT017
+                if sequence_idx == 0:
+                    assert len(finished.value) == 3  # noqa: PT017
+                else:
+                    assert len(finished.value) == 1  # noqa: PT017
+                break
+
+    assert result_count == 2  # Both sequences finished
+
+
+def test_parallel_action_sequences_exception_cases(
+    example_action_server: ExampleActionServer,  # noqa: F811
+    example_action_client: ExampleActionClient,  # noqa: F811
+) -> None:
+    """Test that exceptions in one sequence do not affect others."""
+
+    # Configure the server to raise an exception when processing goals
+    example_action_server.handler.raise_exception = True
+
+    # Set up sequences
+    sequence1 = ActionSequence(
+        ActionGroup(
+            ActionElement(example_action_client.client, RobustActionExample.Goal()),
+        ),
+    )
+
+    sequence2 = ActionSequence(
+        ActionGroup(
+            ActionElement(example_action_client.client, RobustActionExample.Goal()),
+        ),
+    )
+
+    # Start the parallel sequences
+    parallel_sequences = ParallelActionSequences(sequence1, sequence2)
+
+    # Execute and expect an exception
+    with pytest.raises(RobustRPCException.like(CoolError)):
+        testing.exhaust_generator(parallel_sequences.yield_for_execution())
+
+    # Verify that all sequences ended, essentially verifying that the exception in
+    # parallel sequences is caught, stored, and re-raised once the system has fully
+    # synchronized
+    for sequence in parallel_sequences._action_sequences:
+        for action_group in sequence._action_groups:
+            for running_action in action_group._running_actions:
+                assert running_action.result_future.done()
+
+
+def test_parallel_action_sequences_cancellation(
+    example_action_server: ExampleActionServer,  # noqa: F811
+    example_action_client: ExampleActionClient,  # noqa: F811
+) -> None:
+    """Test that ParallelActionSequences can be canceled properly."""
+
+    # Create an event to control when the actions finish
+    example_action_server.handler.done_event = Event()
+
+    # Create two action sequences
+    sequence1 = ActionSequence(
+        ActionGroup(
+            ActionElement(example_action_client.client, RobustActionExample.Goal()),
+            ActionElement(example_action_client.client, RobustActionExample.Goal()),
+        ),
+        ActionGroup(
+            ActionElement(example_action_client.client, RobustActionExample.Goal()),
+        ),
+    )
+
+    sequence2 = ActionSequence(
+        ActionGroup(
+            ActionElement(example_action_client.client,
RobustActionExample.Goal()), + ), + ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ActionElement(example_action_client.client, RobustActionExample.Goal()), + ), + ) + + # Create ParallelActionSequences with the two sequences + parallel_sequences = ParallelActionSequences(sequence1, sequence2) + + # Edit the action server with a separate done event, so it never finishes + example_action_server.handler.done_event = Event() + + # Start the parallel sequences + execute_generator = parallel_sequences.yield_for_execution() + next(execute_generator) + + # Verify the underlying actions are running, and the correct ones at that + assert len(sequence1._action_groups[0]._running_actions) == 2 + assert all( + not a.result_future.done() for a in sequence1._action_groups[0]._running_actions + ) + assert len(sequence1._action_groups[1]._running_actions) == 0 + + assert len(sequence2._action_groups[0]._running_actions) == 1 + assert all( + not a.result_future.done() for a in sequence2._action_groups[0]._running_actions + ) + assert len(sequence2._action_groups[1]._running_actions) == 0 + + # Cancel goals + parallel_sequences.cancel() + + # Verify that all actions are canceled + for sequence in parallel_sequences._action_sequences: + for action_group in sequence._action_groups: + assert len(action_group._running_actions) == 0 + + +def test_parallel_action_sequences_exception_prevents_other_actions( + example_action_server: ExampleActionServer, # noqa: F811 + example_action_client: ExampleActionClient, # noqa: F811 +) -> None: + """Test that exceptions prevent other actions from being called.""" + + # Configure the server to raise an exception when processing goals + example_action_server.handler.raise_exception = True + + # Set up sequences + sequence1 = ActionSequence( + ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()) + for _ in range(2) + ), + ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()) + for _ in range(3) + ), + ) + + sequence2 = ActionSequence( + ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()) + for _ in range(4) + ), + ActionGroup( + ActionElement(example_action_client.client, RobustActionExample.Goal()) + for _ in range(5) + ), + ) + + # Start the parallel sequences + parallel_sequences = ParallelActionSequences(sequence1, sequence2) + + # Execute and expect an exception + with pytest.raises(RobustRPCException.like(CoolError)): + testing.exhaust_generator(parallel_sequences.yield_for_execution()) + + # Verify that no further actions were called after the exception + assert len(sequence1._action_groups[0]._running_actions) == 2 + assert len(sequence1._action_groups[1]._running_actions) == 0 + assert len(sequence2._action_groups[0]._running_actions) == 4 + assert len(sequence2._action_groups[1]._running_actions) == 0 diff --git a/pkgs/node_helpers/node_helpers_test/integration/actions/test_context_manager.py b/pkgs/node_helpers/node_helpers_test/integration/actions/test_context_manager.py new file mode 100644 index 0000000..6f2fe7a --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/actions/test_context_manager.py @@ -0,0 +1,208 @@ +from collections.abc import Generator +from itertools import cycle +from queue import Queue +from typing import Any + +import pytest +from node_helpers.actions.context_manager import ActionContextManager +from node_helpers.nodes import HelpfulNode +from node_helpers.robust_rpc import 
RobustRPCException +from node_helpers.testing import ( + ActionServerCallback, + ConfigurableServiceCallback, + set_up_node, +) +from node_helpers_msgs.action import RobustActionExample +from node_helpers_msgs.action._robust_action_example import RobustActionExample_Feedback +from rclpy.action.server import CancelResponse +from rclpy.callback_groups import ReentrantCallbackGroup + + +class ExampleServerNode(HelpfulNode): + def __init__(self, **kwargs: Any): + super().__init__("example_server_node", **kwargs) + + # Create "action_a", a very basic action that calls the action_callback + self.action_a_callback = ActionServerCallback( + return_values=cycle([RobustActionExample.Result()]), + feedback_values=cycle([RobustActionExample.Feedback()]), + ) + self.action_a_callback.allow_publish_feedback.set() + self.create_robust_action_server( + RobustActionExample, + "/test_context_manager/action_a", + self.action_a_callback, + callback_group=ReentrantCallbackGroup(), + cancel_callback=lambda *args: CancelResponse.ACCEPT, + ) + + # Create "action_b", a very basic action that calls the action_callback + self.action_b_callback = ActionServerCallback( + return_values=cycle([RobustActionExample.Result()]), + feedback_values=cycle([RobustActionExample.Feedback()]), + ) + self.action_b_callback.allow_publish_feedback.set() + self.create_robust_action_server( + RobustActionExample, + "/test_context_manager/action_b", + self.action_b_callback, + callback_group=ReentrantCallbackGroup(), + cancel_callback=lambda *args: CancelResponse.ACCEPT, + ) + + +@pytest.fixture() +def example_server_node() -> Generator[ExampleServerNode, None, None]: + yield from set_up_node( + node_class=ExampleServerNode, + namespace="test_context_manager", + node_name="example_server_node", + multi_threaded=True, + ) + + +class ExampleClientNode(HelpfulNode): + def __init__(self, **kwargs: Any): + super().__init__("example_client_node", **kwargs) + + self.action_a_client = self.create_robust_action_client( + RobustActionExample, "/test_context_manager/action_a" + ) + self.action_b_client = self.create_robust_action_client( + RobustActionExample, "/test_context_manager/action_b" + ) + + +@pytest.fixture() +def example_client_node() -> Generator[ExampleClientNode, None, None]: + yield from set_up_node( + node_class=ExampleClientNode, + namespace="test_context_manager", + node_name="example_client_node", + multi_threaded=True, + ) + + +def test_basic_operation( + example_server_node: ExampleServerNode, example_client_node: ExampleClientNode +) -> None: + """Tests that the ActionContextManager interacts properly with the action""" + action_callback = example_server_node.action_a_callback + assert not action_callback.on_action_started.is_set() + + with ActionContextManager( + example_client_node.action_a_client, RobustActionExample.Goal(), timeout=10 + ): + # Ensure that the action has started and no cancellation has been requested + assert action_callback.on_action_started.is_set() + assert not action_callback.on_cancel_requested.is_set() + + # Allow cancellation before exiting the context + action_callback.allow_cancel.set() + + assert action_callback.call_count == 1 + assert action_callback.on_cancel_requested.is_set() + + +def test_async( + example_server_node: ExampleServerNode, example_client_node: ExampleClientNode +) -> None: + """Tests that two ActionContextManagers can be called concurrently using its async + mode + """ + + call_a = ActionContextManager( + example_client_node.action_a_client, + RobustActionExample.Goal(), + 
timeout=10,
+        async_=True,
+    )
+    call_b = ActionContextManager(
+        example_client_node.action_b_client,
+        RobustActionExample.Goal(),
+        timeout=10,
+        async_=True,
+    )
+    callback_a = example_server_node.action_a_callback
+    callback_b = example_server_node.action_b_callback
+
+    with call_a, call_b:
+        call_a.wait_for_feedback()
+        call_b.wait_for_feedback()
+
+        # Both actions should be running right now
+        assert callback_a.on_action_started.is_set()
+        assert callback_b.on_action_started.is_set()
+        assert not callback_a.on_cancel_requested.is_set()
+        assert not callback_b.on_cancel_requested.is_set()
+
+        # Allow cancellation before exiting the context
+        callback_a.allow_cancel.set()
+        callback_b.allow_cancel.set()
+
+    call_a.wait_for_cancellation()
+    call_b.wait_for_cancellation()
+
+    assert callback_a.on_cancel_requested.is_set()
+    assert callback_b.on_cancel_requested.is_set()
+
+    # Ensure that these wait_for_* methods can be called multiple times
+    call_a.wait_for_feedback()
+    call_a.wait_for_cancellation()
+
+
+def test_check_for_exceptions(
+    example_server_node: ExampleServerNode, example_client_node: ExampleClientNode
+) -> None:
+    """Test that 'check_for_exceptions' checks for exceptions without blocking"""
+
+    action_callback = example_server_node.action_a_callback
+    action_callback._return_value_iterator = ConfigurableServiceCallback(
+        [RobustActionExample.Result(error_name="SomeActionError")]
+    )
+
+    assert not action_callback.on_action_started.is_set()
+
+    # An exception should be raised when the context is exited
+    expected_exception = RobustRPCException.like("SomeActionError")
+    with (
+        pytest.raises(expected_exception),
+        ActionContextManager(
+            example_client_node.action_a_client, RobustActionExample.Goal(), timeout=10
+        ) as handle,
+    ):
+        # Ensure that the action has started and no cancellation has been requested
+        assert action_callback.on_action_started.is_set()
+        assert not action_callback.on_cancel_requested.is_set()
+
+        # Allow the action to abort before exiting the context
+        action_callback.allow_abort.set()
+
+        # 'check_for_exceptions' should also work
+        with pytest.raises(expected_exception):
+            handle.check_for_exceptions()
+
+
+def test_feedback_is_passed_through(
+    example_server_node: ExampleServerNode, example_client_node: ExampleClientNode
+) -> None:
+    feedback_queue: Queue[RobustActionExample_Feedback] = Queue()
+    action_callback = example_server_node.action_a_callback
+
+    context = ActionContextManager(
+        example_client_node.action_a_client,
+        RobustActionExample.Goal(),
+        timeout=10,
+        on_feedback=feedback_queue.put,
+    )
+
+    with context:
+        # Allow cancellation before exiting the context
+        action_callback.allow_cancel.set()
+
+    assert action_callback.call_count == 1
+    assert action_callback.on_cancel_requested.is_set()
+
+    # Ensure that the feedback was passed through
+    assert not feedback_queue.empty()
+    assert feedback_queue.get() == RobustActionExample.Feedback()
diff --git a/pkgs/node_helpers/node_helpers_test/integration/async_tools/__init__.py b/pkgs/node_helpers/node_helpers_test/integration/async_tools/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pkgs/node_helpers/node_helpers_test/integration/async_tools/test_async_adapter.py b/pkgs/node_helpers/node_helpers_test/integration/async_tools/test_async_adapter.py
new file mode 100644
index 0000000..bee81a0
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers_test/integration/async_tools/test_async_adapter.py
@@ -0,0 +1,98 @@
+import asyncio
+import queue
+from collections.abc import Generator
+from typing import Any + +import pytest +from node_helpers.async_tools import AsyncAdapter +from node_helpers.futures import wait_for_future +from node_helpers.nodes import HelpfulNode +from node_helpers.testing import set_up_node +from node_helpers_msgs.msg import SensorExample +from node_helpers_msgs.srv import RobustServiceExample +from rclpy.qos import qos_profile_services_default + + +class ExampleAsyncNode(HelpfulNode): + """A node that uses async callbacks for topic subscriptions and services""" + + def __init__(self, **kwargs: Any): + super().__init__("example_async_node", **kwargs) + + self.async_adapter = AsyncAdapter(self) + + self.message_received: queue.Queue[SensorExample] = queue.Queue() + self.service_received: queue.Queue[RobustServiceExample.Request] = queue.Queue() + + self.create_subscription( + SensorExample, + "sensor", + self.async_adapter.adapt(self.on_message), + qos_profile=qos_profile_services_default, + ) + + self.create_service( + RobustServiceExample, "service", self.async_adapter.adapt(self.on_service) + ) + + async def on_message(self, message: SensorExample) -> None: + await asyncio.sleep(0) + self.message_received.put(message) + + async def on_service( + self, + request: RobustServiceExample.Request, + response: RobustServiceExample.Response, + ) -> RobustServiceExample.Response: + await asyncio.sleep(0) + self.service_received.put(request) + return response + + +class ExampleAsyncNodeClient(HelpfulNode): + def __init__(self, **kwargs: Any): + super().__init__("example_async_node_client", **kwargs) + + self.send_message = self.create_publisher( + SensorExample, "sensor", qos_profile=qos_profile_services_default + ) + + self.call_service = self.create_client(RobustServiceExample, "service") + + +@pytest.fixture() +def example_async_node() -> Generator[ExampleAsyncNode, None, None]: + yield from set_up_node( + node_class=ExampleAsyncNode, + namespace="async_adapter", + node_name="example_async_node", + multi_threaded=True, + ) + + +@pytest.fixture() +def example_async_node_client() -> Generator[ExampleAsyncNodeClient, None, None]: + yield from set_up_node( + node_class=ExampleAsyncNodeClient, + namespace="async_adapter", + node_name="example_async_node_client", + ) + + +def test_callback_adapter( + example_async_node: ExampleAsyncNode, + example_async_node_client: ExampleAsyncNodeClient, +) -> None: + """Tests that the callback adapter works for topics and services""" + + # Topic subscription + example_async_node_client.send_message.publish(SensorExample(value=1)) + message = example_async_node.message_received.get(timeout=5.0) + assert message.value == 1 + + # Service + future = example_async_node_client.call_service.call_async( + RobustServiceExample.Request() + ) + example_async_node.service_received.get(timeout=5.0) + wait_for_future(future, RobustServiceExample.Response, timeout=5.0) diff --git a/pkgs/node_helpers/node_helpers_test/integration/conftest.py b/pkgs/node_helpers/node_helpers_test/integration/conftest.py new file mode 100644 index 0000000..ce6ad32 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/conftest.py @@ -0,0 +1 @@ +from node_helpers.testing import each_test_setup_teardown # noqa: F401 diff --git a/pkgs/node_helpers/node_helpers_test/integration/destruction/__init__.py b/pkgs/node_helpers/node_helpers_test/integration/destruction/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/node_helpers/node_helpers_test/integration/destruction/conftest.py 
b/pkgs/node_helpers/node_helpers_test/integration/destruction/conftest.py new file mode 100644 index 0000000..f8a4445 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/destruction/conftest.py @@ -0,0 +1,20 @@ +from collections.abc import Generator +from typing import Any + +import pytest +from node_helpers.nodes import HelpfulNode +from node_helpers.testing import set_up_node + + +class ExampleNode(HelpfulNode): + def __init__(self, **kwargs: Any): + super().__init__("example_async_node_client", **kwargs) + + +@pytest.fixture() +def example_node() -> Generator[ExampleNode, None, None]: + yield from set_up_node( + node_class=ExampleNode, + namespace="async_adapter", + node_name="example_async_node", + ) diff --git a/pkgs/node_helpers/node_helpers_test/integration/destruction/test_mixin.py b/pkgs/node_helpers/node_helpers_test/integration/destruction/test_mixin.py new file mode 100644 index 0000000..3e3c060 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/destruction/test_mixin.py @@ -0,0 +1,11 @@ +from unittest.mock import MagicMock + +from .conftest import ExampleNode + + +def test_destroy_callbacks(example_node: ExampleNode) -> None: + on_destroy_callback = MagicMock() + example_node.on_destroy(on_destroy_callback) + example_node.destroy_node() + + assert on_destroy_callback.call_count == 1 diff --git a/pkgs/node_helpers/node_helpers_test/integration/interaction/menu/__init__.py b/pkgs/node_helpers/node_helpers_test/integration/interaction/menu/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/node_helpers/node_helpers_test/integration/interaction/menu/conftest.py b/pkgs/node_helpers/node_helpers_test/integration/interaction/menu/conftest.py new file mode 100644 index 0000000..9b2c48e --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/interaction/menu/conftest.py @@ -0,0 +1,30 @@ +from collections.abc import Generator +from typing import Any + +import pytest +from node_helpers.interaction import DashboardMenu +from node_helpers.testing import ( # noqa: F401 + NodeForTesting, + each_test_setup_teardown, + set_up_node, +) +from node_helpers_msgs.srv import ChoosePromptOption + + +class MenuClient(NodeForTesting): + def __init__(self, **kwargs: Any): + super().__init__("menu_client", **kwargs) + + self.choose_option = self.create_robust_client( + srv_type=ChoosePromptOption, srv_name=DashboardMenu.DEFAULT_OPTION_SERVICE + ) + + +@pytest.fixture() +def menu_client() -> Generator[MenuClient, None, None]: + yield from set_up_node( + node_class=MenuClient, + namespace="", + node_name="menu_client", + multi_threaded=True, + ) diff --git a/pkgs/node_helpers/node_helpers_test/integration/interaction/menu/test_dashboard.py b/pkgs/node_helpers/node_helpers_test/integration/interaction/menu/test_dashboard.py new file mode 100644 index 0000000..22302e1 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/interaction/menu/test_dashboard.py @@ -0,0 +1,232 @@ +import json +from collections.abc import Generator +from queue import Queue +from threading import Event +from time import sleep +from typing import Any +from unittest import mock + +import pytest +from action_msgs.msg import GoalStatus +from node_helpers.interaction import DashboardMenu +from node_helpers.interaction.menus.base_menu import DEFAULT_CANCEL_OPTION +from node_helpers.nodes import HelpfulNode +from node_helpers.testing import DynamicContextThread, set_up_node +from node_helpers.timing import TestingTimeout as Timeout +from 
node_helpers_msgs.action import RobustActionExample
+from node_helpers_msgs.action._robust_action_example import (
+    RobustActionExample_GetResult_Response,
+)
+from node_helpers_msgs.msg import PromptOption, UserPrompt
+from node_helpers_msgs.srv import ChoosePromptOption
+from rclpy.action import CancelResponse
+from rclpy.action.server import ServerGoalHandle
+
+from .conftest import MenuClient
+
+
+class SimpleActionNode(HelpfulNode):
+    """Hosts an arbitrary action server and client"""
+
+    EXPECTED_ACTION_RESULT_DATA = "cool-data"
+
+    def __init__(self, **kwargs: Any):
+        super().__init__(node_name="simple_node", **kwargs)
+        self.action_continue_event = Event()
+        """When set, the action call will continue and finish"""
+        self.action_cancel_requested = Event()
+
+        self.server = self.create_robust_action_server(
+            RobustActionExample,
+            "cool_action",
+            execute_callback=self.execute_action,
+            cancel_callback=lambda *args: CancelResponse.ACCEPT,
+        )
+        self.client = self.create_robust_action_client(
+            RobustActionExample, "cool_action"
+        )
+
+    def execute_action(self, goal: ServerGoalHandle) -> RobustActionExample.Result:
+        while not self.action_continue_event.is_set():
+            if goal.is_cancel_requested:
+                self.action_cancel_requested.set()
+                goal.canceled()
+                return RobustActionExample.Result()
+            sleep(0.1)
+
+        goal.succeed()
+        return RobustActionExample.Result(data=self.EXPECTED_ACTION_RESULT_DATA)
+
+
+@pytest.fixture()
+def simple_action_node() -> Generator[SimpleActionNode, None, None]:
+    yield from set_up_node(
+        node_class=SimpleActionNode,
+        namespace="",
+        node_name="simple_node",
+        multi_threaded=True,
+    )
+
+
+@pytest.fixture()
+def menu(menu_client: MenuClient) -> DashboardMenu:
+    menu = DashboardMenu(node=menu_client)
+    menu.connect()
+    return menu
+
+
+def test_display_menu(menu_client: MenuClient, menu: DashboardMenu) -> None:
+    very_helpful_message = "dingus"
+
+    expected_prompt = UserPrompt(
+        options=[
+            PromptOption(
+                name="do an A thing",
+                description="A",
+            ),
+            PromptOption(
+                name="do a B thing",
+                description="B",
+            ),
+        ],
+        help=very_helpful_message,
+        metadata=json.dumps({}),
+    )
+
+    prompt_publisher: "Queue[UserPrompt]" = Queue()
+    selected_a = mock.Mock()
+    selected_b = mock.Mock()
+
+    menu.publish_prompt = prompt_publisher.put
+
+    def display_menu() -> None:
+        menu.display_menu(
+            (expected_prompt.options[0], selected_a),
+            (expected_prompt.options[1], selected_b),
+            help_message=very_helpful_message,
+        )
+
+    with DynamicContextThread(target=display_menu):
+        wait_for_menu_readyness(menu)
+
+        assert selected_b.call_count == 0
+        menu_client.choose_option.call(
+            ChoosePromptOption.Request(option=expected_prompt.options[1])
+        )
+
+    assert selected_a.call_count == 0
+    assert selected_b.call_count == 1
+    assert prompt_publisher.get() == expected_prompt
+
+
+@pytest.mark.parametrize("user_cancelled", (True, False))
+def test_run_cancellable_action(
+    simple_action_node: SimpleActionNode,
+    menu_client: MenuClient,
+    user_cancelled: bool,
+    menu: DashboardMenu,
+) -> None:
+    prompt_publisher: "Queue[UserPrompt]" = Queue()
+    menu.publish_prompt = prompt_publisher.put
+
+    output: "Queue[tuple[RobustActionExample_GetResult_Response, bool]]" = Queue()
+
+    def run_action() -> None:
+        output.put(
+            menu.run_user_cancellable_action(
+                simple_action_node.client,
+                RobustActionExample.Goal(),
+            )
+        )
+
+    with DynamicContextThread(run_action):
+        assert (
+            len(prompt_publisher.get(timeout=5).options) == 1
+        ), "The user should be prompted to cancel!"
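+        # The action is still in flight, so no result should have been produced
+        # yet and no further prompts should be queued.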
+ assert output.qsize() == 0 + assert prompt_publisher.qsize() == 0 + + if user_cancelled: + wait_for_menu_readyness(menu) + + # Select a choice + menu_client.choose_option.call( + ChoosePromptOption.Request(option=DEFAULT_CANCEL_OPTION) + ) + + # Wait for the action to be cancelled + simple_action_node.action_cancel_requested.wait(5) + + # Allow the action to continue + simple_action_node.action_continue_event.set() + result, cancelled = output.get(timeout=15) + + assert cancelled is user_cancelled + assert prompt_publisher.qsize() == 1, "There should be a 'finishing action' log!" + + if user_cancelled: + assert result.status == GoalStatus.STATUS_CANCELED + assert result.result.data == "" + else: + assert result.status == GoalStatus.STATUS_SUCCEEDED + assert result.result.data == simple_action_node.EXPECTED_ACTION_RESULT_DATA + + +def test_run_cancellable_action_with_custom_menu_items( + simple_action_node: SimpleActionNode, menu_client: MenuClient, menu: DashboardMenu +) -> None: + """Test that when using custom menu items, the callbacks are processed as expected + and cancellation still occurs""" + prompt_publisher: "Queue[UserPrompt]" = Queue() + menu.publish_prompt = prompt_publisher.put + + output: "Queue[tuple[RobustActionExample_GetResult_Response, bool]]" = Queue() + + option_1_event = Event() + option_2_event = Event() + option_3_event = Event() + option_1 = PromptOption(name="1") + option_2 = PromptOption(name="2") + option_3 = PromptOption(name="3") + + def run_action() -> None: + output.put( + menu.run_user_cancellable_action( + simple_action_node.client, + RobustActionExample.Goal(), + menu_items=( + (option_1, option_1_event.set), + (option_2, option_2_event.set), + (option_3, option_3_event.set), + ), + ) + ) + + with DynamicContextThread(run_action): + wait_for_menu_readyness(menu) + + # Select the second option out of the three + menu_client.choose_option.call(ChoosePromptOption.Request(option=option_2)) + + # Wait for the action to be cancelled + simple_action_node.action_cancel_requested.wait(5) + + # Allow the action to continue + simple_action_node.action_continue_event.set() + result, cancelled = output.get(timeout=15) + + # Validate that even though some arbitrary option got picked, it was cancelled + assert cancelled + # Validate the expected option was picked + assert not option_1_event.is_set() + assert option_2_event.is_set() + assert not option_3_event.is_set() + # Validate the prompt included three options + assert len(prompt_publisher.get().options) == 3 + + +def wait_for_menu_readyness(menu: DashboardMenu) -> None: + """Wait for the menu to be ready for user input""" + timeout = Timeout(5) + while menu._prompter._ongoing_request is None and timeout: # type: ignore + sleep(0.05) diff --git a/pkgs/node_helpers/node_helpers_test/integration/interaction/prompting/test_dashboard_prompter.py b/pkgs/node_helpers/node_helpers_test/integration/interaction/prompting/test_dashboard_prompter.py new file mode 100644 index 0000000..b1dcd43 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/interaction/prompting/test_dashboard_prompter.py @@ -0,0 +1,121 @@ +from collections.abc import Generator +from time import sleep +from typing import Any + +import pytest +from node_helpers.interaction import DashboardMenu, DashboardPrompter +from node_helpers.interaction.prompting.dashboard_prompter import InvalidPromptError +from node_helpers.nodes import HelpfulNode +from node_helpers.robust_rpc import RobustRPCException +from node_helpers.testing import set_up_node 
+from node_helpers.timing import TestingTimeout as Timeout
+from node_helpers_msgs.msg import PromptOption
+from node_helpers_msgs.srv import ChoosePromptOption
+
+_OPTION_A = PromptOption(name="A", description="ayy lmao")
+_OPTION_B = PromptOption(name="B", description="ayy lmao")
+_OPTION_C = PromptOption(name="C", description="ayy lmao")
+_OPTION_D = PromptOption(name="D", description="ayy lmao")
+
+
+class PrompterNode(HelpfulNode):
+    def __init__(self, **kwargs: Any):
+        super().__init__("prompter", **kwargs)
+        self.choose_option = self.create_robust_client(
+            srv_type=ChoosePromptOption, srv_name=DashboardMenu.DEFAULT_OPTION_SERVICE
+        )
+        self.prompter = DashboardPrompter(self, DashboardMenu.DEFAULT_OPTION_SERVICE)
+        self.prompter.connect()
+
+
+@pytest.fixture()
+def prompter_node() -> Generator[PrompterNode, None, None]:
+    yield from set_up_node(
+        PrompterNode, namespace="prompter", node_name="prompter", multi_threaded=True
+    )
+
+
+@pytest.mark.parametrize(
+    ("options", "chosen_option", "expected_result"),
+    (
+        # Basic multi-option test
+        (
+            ((_OPTION_A, 1), (_OPTION_B, 2), (_OPTION_C, 3), (_OPTION_D, 4)),
+            _OPTION_C,
+            3,
+        ),
+        # Single option test
+        (((_OPTION_A, "choice"),), _OPTION_A, "choice"),
+    ),
+)
+def test_choose(
+    prompter_node: PrompterNode,
+    options: tuple[tuple[PromptOption, Any], ...],
+    chosen_option: PromptOption,
+    expected_result: Any,
+) -> None:
+    choice_future = prompter_node.prompter.choose_async(options)
+    assert not choice_future.done()
+
+    # Choose the option. This also tests that only the "name" is looked at, and that the
+    # description does not matter.
+    prompter_node.choose_option.call(ChoosePromptOption.Request(option=chosen_option))
+    assert choice_future.done()
+    assert choice_future.result() == expected_result
+
+
+def test_choose_invalid_option(prompter_node: PrompterNode) -> None:
+    # Test requesting an option when there are no prompts ongoing
+    with pytest.raises(RobustRPCException.like(InvalidPromptError)):
+        prompter_node.choose_option.call(ChoosePromptOption.Request(option=_OPTION_C))
+
+    # Test requesting an option that isn't the one currently being prompted for
+    choice_future = prompter_node.prompter.choose_async(((_OPTION_C, "option c"),))
+    with pytest.raises(RobustRPCException.like(InvalidPromptError)):
+        prompter_node.choose_option.call(ChoosePromptOption.Request(option=_OPTION_A))
+
+    assert not choice_future.done()
+
+    # Now validate that getting the correct option _does_ work
+    prompter_node.choose_option.call(ChoosePromptOption.Request(option=_OPTION_C))
+    assert choice_future.result() == "option c"
+
+
+def test_disconnect(prompter_node: PrompterNode) -> None:
+    choice_future = prompter_node.prompter.choose_async(((_OPTION_A, 1),))
+
+    # Assert initial state, which will all be torn down by the end
+    assert prompter_node.choose_option.wait_for_service(10)
+    assert prompter_node.choose_option.service_is_ready()
+    assert prompter_node.prompter._ongoing_request is not None
+    assert not choice_future.done()
+
+    # Close the prompter
+    prompter_node.prompter.disconnect()
+
+    # The ongoing request should have been cancelled
+    assert prompter_node.prompter._ongoing_request is None
+
+    # The current future should have an exception set
+    assert choice_future.done()
+    with pytest.raises(RuntimeError):
+        choice_future.result()
+
+    # Wait for the service to be destroyed
+    timeout = Timeout(5)
+    while timeout and prompter_node.choose_option.service_is_ready():
+        sleep(0.1)
+    assert not prompter_node.choose_option.service_is_ready()
+
+
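+# A minimal usage sketch of the prompter API exercised by the tests above. This
+# is illustrative only: `my_node` stands in for any node running on a
+# multi-threaded executor, and the (option, value) pairing mirrors the
+# `choose_async` calls in `test_choose`.
+#
+#     prompter = DashboardPrompter(my_node, DashboardMenu.DEFAULT_OPTION_SERVICE)
+#     prompter.connect()
+#
+#     # Offer two options; the future resolves to the value paired with the
+#     # option named by an incoming ChoosePromptOption request
+#     future = prompter.choose_async(
+#         ((PromptOption(name="Retry"), "retry"), (PromptOption(name="Abort"), "abort"))
+#     )
+#     if future.done():
+#         decision = future.result()
+#
+#     # Tears down the service; any still-pending future gets a RuntimeError set
+#     prompter.disconnect()
+
+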
+def test_connecting_and_disconnecting(prompter_node: PrompterNode) -> None: + new_prompter = DashboardPrompter( + prompter_node, DashboardMenu.DEFAULT_OPTION_SERVICE + ) + assert new_prompter._service is None, "Service should not be created in the init" + + new_prompter.connect() + assert new_prompter._service is not None, "Service should be created after connect" + + new_prompter.disconnect() + assert new_prompter._service is None, "Service should be destroyed after disconnect" diff --git a/pkgs/node_helpers/node_helpers_test/integration/nodes/test_interactive_transform_publisher.py b/pkgs/node_helpers/node_helpers_test/integration/nodes/test_interactive_transform_publisher.py new file mode 100644 index 0000000..3ed2909 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/nodes/test_interactive_transform_publisher.py @@ -0,0 +1,467 @@ +import contextlib +import logging +from collections.abc import Generator +from queue import Empty +from tempfile import NamedTemporaryFile +from typing import Any + +import numpy as np +import pytest +from builtin_interfaces.msg import Time +from geometry_msgs.msg import Point, Pose, Quaternion, TransformStamped +from interactive_markers.interactive_marker_server import MarkerContext +from node_helpers.nodes.interactive_transform_publisher import ( + DuplicateTransformError, + InteractiveTransformPublisher, + MultipleParentsError, + TransformModel, + TransformsFile, +) +from node_helpers.nodes.interactive_transform_publisher.client import ( + InteractiveTransformClient, +) +from node_helpers.ros2_numpy import msgify, numpify +from node_helpers.testing import NodeForTesting, set_up_node +from rclpy import Parameter +from rclpy.qos import ( + DurabilityPolicy, + HistoryPolicy, + QoSProfile, + qos_profile_services_default, +) +from tf2_msgs.msg import TFMessage +from visualization_msgs.msg import InteractiveMarker, InteractiveMarkerFeedback + +_DEFAULT_TRANSFORMS = ["april_tag:april_tag_origin", "cool_parent:cool_child"] + + +class TransformClient(NodeForTesting, InteractiveTransformClient): + def __init__(self, **kwargs: Any) -> None: + super().__init__("transform_client", **kwargs) # type: ignore + InteractiveTransformClient.__init__(self, self) + + self.tf_static = self.create_queue_subscription( + type_=TFMessage, + topic="/tf_static", + qos_profile=QoSProfile( + depth=100, + durability=DurabilityPolicy.TRANSIENT_LOCAL, + history=HistoryPolicy.KEEP_ALL, + ), + ) + self.feedback = self.create_publisher( + InteractiveMarkerFeedback, + "feedback", + qos_profile=qos_profile_services_default, + ) + + +_TRANSFORM_1 = TransformModel( + parent="april_tag", + child="april_tag_origin", + created_via_api=False, + translation=(1.0, 2.0, 3.0), + rotation=(0.1, 2.0, 0.3, 1.0), +) +_TRANSFORM_2 = TransformModel( + parent="cool_parent", child="cool_child", created_via_api=False +) +_TRANSFORM_3 = TransformModel( + parent="darn", child="diggity", created_via_api=True, translation=(7.0, 8.0, 9.0) +) +_TRANSFORMS_FILE = TransformsFile(transforms=[_TRANSFORM_1, _TRANSFORM_2, _TRANSFORM_3]) + + +@contextlib.contextmanager +def set_up_interactive_transform_publisher( + rosparam_transforms: list[str], +) -> Generator[InteractiveTransformPublisher, None, None]: + """Start an InteractiveTransformPublisher with a tempdir and specific config""" + with NamedTemporaryFile("w") as config_file: + config_file.write(_TRANSFORMS_FILE.model_dump_json()) + config_file.flush() + + generator = set_up_node( + InteractiveTransformPublisher, + "calibration", + 
"interactive_transform_publisher", + parameter_overrides=[ + Parameter(name="static_transforms_file", value=config_file.name), + Parameter(name="transforms", value=rosparam_transforms), + # Disable constant re-publishing of TFs to make testing easier + Parameter(name="tf_publish_frequency", value=0.00001), + ], + ) + + try: + yield next(generator) + finally: + tuple(generator) + + +@pytest.fixture() +def transform_publisher() -> Generator[InteractiveTransformPublisher, None, None]: + with set_up_interactive_transform_publisher(_DEFAULT_TRANSFORMS) as publisher: + yield publisher + + +@pytest.fixture() +def transform_client() -> Generator[TransformClient, None, None]: + yield from set_up_node(TransformClient, "calibration", "transform_client") + + +def test_initialization( + transform_publisher: InteractiveTransformPublisher, + transform_client: TransformClient, +) -> None: + """Test the contents of the __init__, which should load from a file, create a + marker server, and publish static transforms.""" + transform_models = [t.model for t in transform_publisher.transforms] + assert transform_models == [_TRANSFORM_1, _TRANSFORM_2, _TRANSFORM_3] + assert len(transform_publisher.interaction_server.marker_contexts) == 3 + + for transform_description in transform_publisher.transforms: + transform = transform_description.model + + marker_name = transform_description.model.marker_name + marker_context: MarkerContext = ( + transform_publisher.interaction_server.marker_contexts[marker_name] + ) + marker = marker_context.int_marker + assert marker.header.frame_id == transform.parent + assert np.allclose( + numpify(marker.pose.position), + np.array(transform.translation) / transform_publisher.scale_factor, + ) + assert np.allclose(numpify(marker.pose.orientation), transform.rotation) + + # Validate static transform are published, and they match expectations + for transform_description in transform_publisher.transforms: + tf_message = transform_client.tf_static.get(timeout=10) + transform = transform_description.model + + assert len(tf_message.transforms) == 1 + published_tf: TransformStamped = tf_message.transforms[0] + assert published_tf.header.frame_id == transform.parent + assert published_tf.child_frame_id == transform.child + assert np.allclose( + numpify(published_tf.transform.translation), + transform.translation, + ) + assert np.allclose(numpify(published_tf.transform.rotation), transform.rotation) + + +def test_duplicate_transforms( + transform_publisher: InteractiveTransformPublisher, +) -> None: + some_new_transform = TransformModel( + parent="oh im a cool parent, for sure", + child="stop it! 
you're embarrassing me!", + created_via_api=False, + ) + # No exception + transform_publisher._register_transform(some_new_transform) + + transforms_before = transform_publisher.transforms.copy() + + # The duplicate should fail + with pytest.raises(DuplicateTransformError): + transform_publisher._register_transform(some_new_transform) + + # Nothing should have changed + assert transforms_before == transform_publisher.transforms + + +def test_multiple_parents_not_allowed( + transform_publisher: InteractiveTransformPublisher, +) -> None: + transform_a = TransformModel( + parent="now now child, it's important you don't go off with strangers", + child="impressionable child", + created_via_api=False, + ) + transform_b_with_different_parent = TransformModel( + parent="stranger (oh no, danger!)", + child="impressionable child", + created_via_api=False, + ) + + # No exception + transform_publisher._register_transform(transform_a) + + transforms_before = transform_publisher.transforms.copy() + + # The duplicate should fail + with pytest.raises(MultipleParentsError): + transform_publisher._register_transform(transform_b_with_different_parent) + + # Nothing should have changed + assert transforms_before == transform_publisher.transforms + + +def test_feedback( + transform_publisher: InteractiveTransformPublisher, + transform_client: TransformClient, +) -> None: + """Test functionality in the feedback callback.""" + expected_rotation_1 = np.array((0.4, 0.2, 0.3, 0.5)) + expected_translation_1 = np.array((1.0, 1.1, 1.2)) + expected_marker_name = _TRANSFORM_1.marker_name + published_feedback = InteractiveMarkerFeedback( + marker_name=expected_marker_name, + pose=Pose( + position=msgify(Point, expected_translation_1), + orientation=msgify(Quaternion, expected_rotation_1), + ), + event_type=InteractiveMarkerFeedback.POSE_UPDATE, + ) + + # First clear the transforms that were published in the __init__ + for _ in _TRANSFORMS_FILE.transforms: + transform_client.tf_static.get(timeout=5) + + # Publish feedback and validate the response + transform_client.feedback.publish(published_feedback) + retrieved_transforms = transform_client.tf_static.get(timeout=10) + assert len(retrieved_transforms.transforms) == 1 + + # Validate that the feedback published an updated matching static transform + # except it's been modified by the scale_factor parameters + transform: TransformStamped = retrieved_transforms.transforms[0] + assert transform.header.frame_id == _TRANSFORM_1.parent + assert transform.header.stamp.sec != 0 + assert np.allclose( + numpify(transform.transform.translation), + numpify(published_feedback.pose.position) * transform_publisher.scale_factor, + ) + assert np.allclose( + numpify(transform.transform.rotation), + numpify(published_feedback.pose.orientation), + ) + + # Validate that the published feedback modified the internal self.transforms + assert ( + transform_publisher.transforms[1].model == _TRANSFORM_2 + ), "No change expected!" + assert np.allclose( + transform_publisher.transforms[0].model.rotation, expected_rotation_1 + ) + transform_data = transform_publisher.transforms_path.read_text() + assert ( + TransformsFile.model_validate_json(transform_data) == _TRANSFORMS_FILE + ), "The file should only be updated on a MOUSE_UP event!" + + # Now publish the exact same feedback, but with a MOUSE_UP event, and validate that + # the file was updated to reflect the new change. 
+ published_feedback.event_type = InteractiveMarkerFeedback.MOUSE_UP + transform_client.feedback.publish(published_feedback) + # Wait for the static transform to get published + transform_client.tf_static.get(timeout=5) + transforms_data = transform_publisher.transforms_path.read_text() + loaded_file = TransformsFile.model_validate_json(transforms_data) + assert loaded_file == TransformsFile( + transforms=[ + TransformModel( + parent=_TRANSFORM_1.parent, + child=_TRANSFORM_1.child, + created_via_api=False, + rotation=tuple(expected_rotation_1.tolist()), + translation=tuple( + (expected_translation_1 * transform_publisher.scale_factor).tolist() + ), + ), + _TRANSFORM_2, + _TRANSFORM_3, + ] + ) + + +def test_publishing_invalid_transform( + transform_publisher: InteractiveTransformPublisher, + transform_client: TransformClient, +) -> None: + """Test that when 'updating' an untracked parent->child, nothing changed""" + # First clear the transforms that were published in the __init__ + for _ in _TRANSFORMS_FILE.transforms: + transform_client.tf_static.get(timeout=5) + + # Now try publishing an invalid transform + file_before = transform_publisher.transforms_path.read_text() + model = TransformModel( + parent="untracked_p", child="untracked_c", created_via_api=False + ) + transform_client.update_transform.publish(model.to_msg(Time())) + with pytest.raises(Empty): + transform_client.tf_static.get(timeout=1) + assert transform_publisher.transforms_path.read_text() == file_before + + +def test_on_update_transform( + transform_publisher: InteractiveTransformPublisher, + transform_client: TransformClient, +) -> None: + """Test that when a valid transform is published, that it is then saved to the file + and republished under tf_static""" + # First clear the transforms that were published in the __init__ + for _ in _TRANSFORMS_FILE.transforms: + transform_client.tf_static.get(timeout=5) + + with pytest.raises(Empty): + transform_client.tf_static.get_nowait() + + published = TransformModel( + parent=_TRANSFORM_1.parent, + child=_TRANSFORM_1.child, + created_via_api=False, + rotation=(0, 1, 2, 3), + translation=(2, 5, 6), + ) + + # Publish the tf static update + transform_client.update_transform.publish(published.to_msg(Time(sec=3234))) + + # Validate the same message was re-published on tf_static + tf_static_publish = transform_client.tf_static.get(timeout=1) + assert tf_static_publish == TFMessage(transforms=[published.to_msg(Time(sec=3234))]) + + # Validate the file was also updated + transforms_data = transform_publisher.transforms_path.read_text() + loaded_file = TransformsFile.model_validate_json(transforms_data) + assert loaded_file == TransformsFile( + transforms=[published, _TRANSFORM_2, _TRANSFORM_3] + ) + + +def test_on_create_transform( + transform_publisher: InteractiveTransformPublisher, + transform_client: TransformClient, +) -> None: + """Test the 'tf_static_create' topic creates only once, and ignores further inputs + for the same transform.""" + + # First clear the transforms that were published in the __init__ + for _ in _TRANSFORMS_FILE.transforms: + transform_client.tf_static.get(timeout=5) + + published = TransformModel( + parent="p", child="c", created_via_api=True, rotation=(0, 1, 2, 3) + ) + + # Publish the tf static creation + transform_client.create_transform.publish(published.to_msg(Time(sec=5819))) + + # This should publish two messages. One with the time == Time() and another with the + # correct time. This behavior might be helpful, I'm not sure. 
+    tf_static_publish = transform_client.tf_static.get(timeout=1)
+    assert tf_static_publish == TFMessage(transforms=[published.to_msg(Time())])
+
+    tf_static_publish = transform_client.tf_static.get(timeout=1)
+    assert tf_static_publish == TFMessage(transforms=[published.to_msg(Time(sec=5819))])
+
+    # Validate the file was also updated, with created_via_api=True
+    transforms_data = transform_publisher.transforms_path.read_text()
+    loaded_file = TransformsFile.model_validate_json(transforms_data)
+    assert loaded_file == TransformsFile(
+        transforms=[_TRANSFORM_1, _TRANSFORM_2, _TRANSFORM_3, published]
+    )
+
+    # Try publishing again with a different translation, and validate nothing changed
+    modified_transform = published.model_copy()
+    modified_transform.translation = (1.0, 5.0, 3.0)
+    transform_client.create_transform.publish(modified_transform.to_msg(Time(sec=10001)))
+
+    # Nothing should have been published
+    with pytest.raises(Empty):
+        transform_client.tf_static.get(timeout=0.1)
+
+    # The file shouldn't have changed
+    assert loaded_file == TransformsFile(
+        transforms=[_TRANSFORM_1, _TRANSFORM_2, _TRANSFORM_3, published]
+    )
+
+
+@pytest.mark.parametrize(
+    ("transforms_param", "expected_transforms"),
+    (
+        # Test specifying a TF that doesn't exist in the file
+        (
+            [*_DEFAULT_TRANSFORMS, "new_tf_parent:new_tf_child"],
+            [
+                _TRANSFORM_1,
+                _TRANSFORM_2,
+                _TRANSFORM_3,
+                TransformModel(
+                    parent="new_tf_parent", child="new_tf_child", created_via_api=False
+                ),
+            ],
+        ),
+        # Test specifying fewer TFs than exist in the file (TF3 should exist because
+        # it was created by API, not by ROS params)
+        ([_DEFAULT_TRANSFORMS[1]], [_TRANSFORM_2, _TRANSFORM_3]),
+        # Test the exact TFs that exist in the file
+        (_DEFAULT_TRANSFORMS, _TRANSFORMS_FILE.transforms),
+        # Test completely different transforms than what's in the file
+        (
+            ["a_parent:a_child", "amongus:amogus"],
+            [
+                TransformModel(
+                    parent="a_parent", child="a_child", created_via_api=False
+                ),
+                TransformModel(parent="amongus", child="amogus", created_via_api=False),
+                _TRANSFORM_3,
+            ],
+        ),
+    ),
+)
+def test_state_recovery(
+    transforms_param: list[str], expected_transforms: list[TransformModel]
+) -> None:
+    """Test that the InteractiveTransformPublisher loads transforms as expected"""
+    with set_up_interactive_transform_publisher(transforms_param) as publisher:
+        publisher_transform_models = [t.model for t in publisher.transforms]
+        publisher_transform_models.sort(key=lambda t: t.marker_name)
+        expected_transforms.sort(key=lambda t: t.marker_name)
+        assert publisher_transform_models == expected_transforms
+
+
+def test_on_publish_transform_updates_marker_server_state(
+    transform_publisher: InteractiveTransformPublisher,
+    transform_client: TransformClient,
+) -> None:
+    """This test verifies there's no longer a bug in the InteractiveTransformPublisher
+    where if a transform is modified using `_on_update_transform`, the marker server
+    won't know about it, so Rviz will show stale transform positions.
+ """ + published = TransformModel( + parent=_TRANSFORM_1.parent, + child=_TRANSFORM_1.child, + created_via_api=False, + rotation=(0, -0.755, 0.252, 0.606), + translation=(1, 0, 0), + ).to_msg(Time()) + + # First clear the transforms that were published in the __init__ + for _ in _TRANSFORMS_FILE.transforms: + transform_client.tf_static.get(timeout=10) + + # Publish the transform and verify it made it to the other side as expected + transform_client.update_transform.publish(published) + retrieved_transforms = transform_client.tf_static.get(timeout=10).transforms + assert len(retrieved_transforms) == 1 + assert np.isclose( + numpify(retrieved_transforms[0].transform), + numpify(published.transform), + ).all() + + # Ensure that the marker server has the updated transform + marker_name = _TRANSFORM_1.marker_name + context = transform_publisher.interaction_server.marker_contexts[marker_name] + marker: InteractiveMarker = context.int_marker + assert len(transform_publisher.interaction_server.marker_contexts) == 3 + assert np.isclose( + numpify(marker.pose.orientation), numpify(published.transform.rotation) + ).all() + assert np.isclose( + numpify(marker.pose.position), + numpify(published.transform.translation) / transform_publisher.scale_factor, + ).all() diff --git a/pkgs/node_helpers/node_helpers_test/integration/parameters/test_parameters_mixin_dynamic_parameter_updates.py b/pkgs/node_helpers/node_helpers_test/integration/parameters/test_parameters_mixin_dynamic_parameter_updates.py new file mode 100644 index 0000000..0285002 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/parameters/test_parameters_mixin_dynamic_parameter_updates.py @@ -0,0 +1,247 @@ +from collections.abc import Generator +from enum import Enum +from pathlib import Path +from typing import Any + +import pytest +from node_helpers.parameters import ParameterMixin +from node_helpers.testing import set_up_node +from pydantic import BaseModel +from rcl_interfaces.msg import SetParametersResult +from rclpy.node import Node +from rclpy.parameter import Parameter + +PARAMETER_NAME_A = "parameter_a" +PARAMETER_VALUE_A = 1.4 +PARAMETER_NAME_B = "parameter_b" +PARAMETER_VALUE_B = "test" + + +class ParameterNode(Node, ParameterMixin): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__("parameters_node", *args, **kwargs) + self.a_parameter_attr = self.declare_and_get_parameter( + name=PARAMETER_NAME_A, type_=float, required=True + ) + self.b_parameter_attr = self.declare_and_get_parameter( + name=PARAMETER_NAME_B, type_=str, required=True + ) + + +@pytest.fixture() +def parameter_node() -> Generator[ParameterNode, None, None]: + yield from set_up_node( + node_class=ParameterNode, + namespace="cool_params", + node_name="parameters_node", + parameter_overrides=[ + Parameter(name=PARAMETER_NAME_A, value=PARAMETER_VALUE_A), + Parameter(name=PARAMETER_NAME_B, value=PARAMETER_VALUE_B), + ], + ) + + +def test_successful_subscribe_attribute_to_updates( + parameter_node: ParameterNode, +) -> None: + assert parameter_node.a_parameter_attr == PARAMETER_VALUE_A + + # Dynamically update the parameter, and assert nothing changes (yet) + parameter_node.set_parameters([Parameter(name=PARAMETER_NAME_A, value=432.619)]) + assert parameter_node.a_parameter_attr == PARAMETER_VALUE_A + + # Subscribe the attribute to updates + parameter_node.subscribe_attribute_to_updates( + "a_parameter_attr", PARAMETER_NAME_A, float + ) + + # Previous changes shouldn't have affected the parameter + assert 
parameter_node.a_parameter_attr == PARAMETER_VALUE_A
+
+    # Publish a new parameter change, now it should have changed
+    results: list[SetParametersResult] = parameter_node.set_parameters(
+        [Parameter(name=PARAMETER_NAME_A, value=1234.5678)]
+    )
+    assert len(results) == 1
+    assert results[0].successful
+    assert parameter_node.a_parameter_attr == 1234.5678
+
+
+def test_setting_unsubscribed_attribute_returns_failure_message(
+    parameter_node: ParameterNode,
+) -> None:
+    """Test that when one attribute is subscribed to updates and another isn't,
+    attempting to set the unsubscribed one returns a failure message.
+
+    Then subscribe both, and try again.
+    """
+    # Subscribe attribute A but don't subscribe attribute B
+    parameter_node.subscribe_attribute_to_updates(
+        "a_parameter_attr", PARAMETER_NAME_A, float
+    )
+
+    # Try updating both A and B
+    results = parameter_node.set_parameters(
+        [
+            Parameter(name=PARAMETER_NAME_A, value=5.0),
+            Parameter(name=PARAMETER_NAME_B, value="cool-string"),
+        ],
+    )
+    assert len(results) == 2
+
+    # Check that only the subscribed attribute changed
+    assert parameter_node.a_parameter_attr == 5.0
+    assert parameter_node.b_parameter_attr == PARAMETER_VALUE_B
+
+    # Now subscribe B and try again
+    parameter_node.subscribe_attribute_to_updates(
+        "b_parameter_attr", PARAMETER_NAME_B, str
+    )
+
+    # Try updating both A and B
+    results = parameter_node.set_parameters(
+        [
+            Parameter(name=PARAMETER_NAME_A, value=10.0),
+            Parameter(name=PARAMETER_NAME_B, value="cool-string-2"),
+        ],
+    )
+    assert len(results) == 2
+
+    # Check that both attributes changed this time
+    assert parameter_node.a_parameter_attr == 10.0
+    assert parameter_node.b_parameter_attr == "cool-string-2"
+
+
+def test_setting_same_parameter_twice(parameter_node: ParameterNode) -> None:
+    """Test that when setting the same parameter multiple times in one call, the last
+    value is used."""
+    parameter_node.subscribe_attribute_to_updates(
+        "a_parameter_attr", PARAMETER_NAME_A, float
+    )
+    parameter_node.set_parameters(
+        [
+            Parameter(name=PARAMETER_NAME_A, value=100.1),
+            Parameter(name=PARAMETER_NAME_A, value=100.2),
+            Parameter(name=PARAMETER_NAME_A, value=100.3),
+            Parameter(name=PARAMETER_NAME_A, value=100.4),
+        ]
+    )
+    assert parameter_node.a_parameter_attr == 100.4
+
+
+def test_invalid_attribute_raises_error(parameter_node: ParameterNode) -> None:
+    # Success case
+    parameter_node.subscribe_attribute_to_updates(
+        "a_parameter_attr", PARAMETER_NAME_A, float
+    )
+
+    # Failure case
+    with pytest.raises(AttributeError):
+        parameter_node.subscribe_attribute_to_updates(
+            "nonexistent_attr", PARAMETER_NAME_A, float
+        )
+
+
+def test_pydantic_parameter_updating(parameter_node: ParameterNode) -> None:
+    """Test that parameter updates are received by ROS and then applied directly to
+    the pydantic object.
+ """ + + class CoolModel(BaseModel): + param_a: str = "default_value_a" + param_b: str = "default_value_b" + + model = parameter_node.declare_from_pydantic_model(CoolModel, "cool_prefix") + assert model.param_a == "default_value_a" + assert model.param_b == "default_value_b" + + # Try updating one value + parameter_node.set_parameters([Parameter(name="cool_prefix.param_a", value="try1")]) + assert model.param_a == "try1" + assert model.param_b == "default_value_b" + + # Try updating the other value + parameter_node.set_parameters([Parameter(name="cool_prefix.param_b", value="try2")]) + assert model.param_a == "try1" + assert model.param_b == "try2" + + # Try updating both values at once + parameter_node.set_parameters( + [ + Parameter(name="cool_prefix.param_a", value="try3a"), + Parameter(name="cool_prefix.param_b", value="try3b"), + ] + ) + assert model.param_a == "try3a" + assert model.param_b == "try3b" + + +def test_nonros_arbitrary_type_parameters_updating( + parameter_node: ParameterNode, +) -> None: + """Test that arbitrary types like Path or Enum are allowed, and when updated they + maintain the desired arbitrary type.""" + + class ExampleEnum(Enum): + OPTION_1 = "option_1" + OPTION_2 = "option_2" + + original_path_value = Path("/a/path/somewhere") + original_enum_value = ExampleEnum.OPTION_2 + + class ArbitraryTypeModel(BaseModel): + some_path: Path = original_path_value + some_enum: ExampleEnum = original_enum_value + + model = parameter_node.declare_from_pydantic_model( + ArbitraryTypeModel, "cool_prefix" + ) + assert model.some_path == original_path_value + assert model.some_enum == original_enum_value + + # Set the path parameter using a string, and validate it was re-converted to a Path + updated_path_raw_value = "/a/new/path/somewhere" + parameter_node.set_parameters( + [Parameter(name="cool_prefix.some_path", value=updated_path_raw_value)] + ) + assert isinstance(model.some_path, Path) + assert model.some_path == Path(updated_path_raw_value) + + # Set the enum parameter using a string, and validate it was re-converted to a Enum + updated_enum_raw_value = "option_1" + parameter_node.set_parameters( + [Parameter(name="cool_prefix.some_enum", value=updated_enum_raw_value)] + ) + assert isinstance(model.some_enum, ExampleEnum) + assert model.some_enum == ExampleEnum(updated_enum_raw_value) + + +def test_subscribe_to_updates_argument_is_respected( + parameter_node: ParameterNode, +) -> None: + """Test that the `subscribe_to_updates` in declare_from_pydantic_model is + respected.""" + + class ChildModel(BaseModel): + third_value: str = "ayy" + fourth_value: str = "bingobongo" + + class SomeModel(BaseModel): + some_value: str = "value" + another_value: int = 3 + + # Test that the subscribe_to_updates argument is recursively passed to children + child: ChildModel + + assert len(parameter_node._on_set_parameters_callbacks) == 1 + + # Test disabling updates + parameter_node.declare_from_pydantic_model( + SomeModel, "some_prefix", subscribe_to_updates=False + ) + assert len(parameter_node._on_set_parameters_callbacks) == 1 + + # Test enabling updates + parameter_node.declare_from_pydantic_model( + SomeModel, "another_prefix", subscribe_to_updates=True + ) + assert len(parameter_node._on_set_parameters_callbacks) == 5 diff --git a/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/__init__.py b/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/conftest.py b/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/conftest.py new file mode 100644 index 0000000..efe9e7f --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/conftest.py @@ -0,0 +1,134 @@ +from collections.abc import Generator +from threading import Event +from time import sleep +from typing import Any + +import pytest +from node_helpers.robust_rpc import RobustRPCException, RobustRPCMixin +from node_helpers.testing import set_up_node +from node_helpers_msgs.action import RobustActionExample +from node_helpers_msgs.srv import RobustServiceExample +from rclpy.action.server import CancelResponse, ServerGoalHandle +from rclpy.node import Node + + +class RobustCaller(Node, RobustRPCMixin): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__("caller", *args, **kwargs) + + self.service_client = self.create_robust_client( + srv_type=RobustServiceExample, srv_name="robust_service" + ) + self.action_client = self.create_robust_action_client( + action_type=RobustActionExample, action_name="robust_action" + ) + + +class RobustImplementer(Node, RobustRPCMixin): + class OhMyAnError(Exception): + pass + + ERROR_MSG = "Oh wow an error message? In _this_ code?? I would never. " + EXPECTED_DATA = "Oh my, cool data that my service/action returned!" + EXPECTED_CANCELLED_DATA = "Oh my, the action was cancelled" + + ROBUST_RPC_ERROR_NAME = "SomeCoolError" + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__("implementer", *args, **kwargs) + self.raise_error = False + self.raise_robust_rpc_error = False + self.action_cancellable = True + + self.action_reached_event = Event() + """This event is set when the action is called""" + + self.action_continue_event: Event | None = None + """If set, the action will wait until this event is set""" + + self.service_calls = 0 + self.action_calls = 0 + + self.robust_service = self.create_robust_service( + srv_type=RobustServiceExample, + srv_name="robust_service", + callback=self.service_callback, + ) + + self.robust_action_server = self.create_robust_action_server( + action_type=RobustActionExample, + action_name="robust_action", + execute_callback=self.action_callback, + cancel_callback=lambda cancel_request: ( + CancelResponse.ACCEPT + if self.action_cancellable + else CancelResponse.REJECT + ), + ) + + def action_callback(self, goal: ServerGoalHandle) -> RobustActionExample.Result: + self.action_reached_event.set() + self.action_calls += 1 + self._maybe_exception() + + if self.action_continue_event is not None: + # Wait for either the action to be cancelled, or for the event to be set + while not self.action_continue_event.is_set(): + # Check for cancellation requests (requires multithreaded executor) + if goal.is_cancel_requested: + goal.canceled() + return RobustActionExample.Result(data=self.EXPECTED_CANCELLED_DATA) + sleep(0.05) + + goal.succeed() + return RobustActionExample.Result(data=self.EXPECTED_DATA) + + def service_callback( + self, + request: RobustServiceExample.Request, + response: RobustServiceExample.Response, + ) -> RobustServiceExample.Response: + self.service_calls += 1 + self._maybe_exception() + response.data = self.EXPECTED_DATA + return response + + def _maybe_exception(self) -> None: + if self.raise_error: + raise self.OhMyAnError(self.ERROR_MSG) + elif self.raise_robust_rpc_error: + raise RobustRPCException( + error_name=self.ROBUST_RPC_ERROR_NAME, + error_description=self.ERROR_MSG, + 
message=RobustServiceExample.Request(), + ) + + +@pytest.fixture() +def caller() -> Generator[RobustCaller, None, None]: + yield from set_up_node( + node_class=RobustCaller, + namespace="robust", + node_name="caller", + multi_threaded=False, + ) + + +@pytest.fixture() +def implementer() -> Generator[RobustImplementer, None, None]: + yield from set_up_node( + node_class=RobustImplementer, + namespace="robust", + node_name="implementer", + multi_threaded=False, + ) + + +@pytest.fixture() +def threaded_implementer() -> Generator[RobustImplementer, None, None]: + yield from set_up_node( + node_class=RobustImplementer, + namespace="robust", + node_name="implementer", + multi_threaded=True, + ) diff --git a/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_action_client/__init__.py b/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_action_client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_action_client/test_action_client.py b/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_action_client/test_action_client.py new file mode 100644 index 0000000..aade4e1 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_action_client/test_action_client.py @@ -0,0 +1,37 @@ +from unittest import mock + +import pytest +from node_helpers.robust_rpc import ExecutorNotSetError +from node_helpers.testing import rclpy_context +from node_helpers_msgs.action import RobustActionExample + +from node_helpers_test.integration.robust_rpc.conftest import ( + RobustCaller, + RobustImplementer, +) + + +def test_waits_for_action(caller: RobustCaller, implementer: RobustImplementer) -> None: + """Test that actions are waiting for readiness on the first call""" + with mock.patch.object( + caller.action_client, + "wait_for_server", + wraps=caller.action_client.wait_for_server, + ) as wait_fn: + assert wait_fn.call_count == 0 + + caller.action_client.send_goal_async(goal=RobustActionExample.Goal()) + + assert wait_fn.call_count == 1 + + caller.action_client.send_goal_async(goal=RobustActionExample.Goal()) + assert wait_fn.call_count == 1, "Waiting should only occur on the first call!" 
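+
+
+# The behavior verified above, sketched for reference (illustrative only, not
+# the library's actual implementation): the robust client performs a one-time
+# `wait_for_server` before the first goal, then skips the wait on later calls.
+#
+#     class LazyWaitingClient:
+#         def __init__(self, client):
+#             self._client = client
+#             self._waited = False
+#
+#         def send_goal_async(self, goal):
+#             if not self._waited:
+#                 self._client.wait_for_server()
+#                 self._waited = True
+#             return self._client.send_goal_async(goal)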
+
+
+def test_fails_without_executor() -> None:
+    """Test that attempting to call an action without the executor set fails"""
+    with rclpy_context() as context:
+        caller = RobustCaller(namespace="something", context=context)
+
+        with pytest.raises(ExecutorNotSetError):
+            caller.action_client.send_goal_async(RobustActionExample.Goal())
diff --git a/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_action_client/test_send_goal_as_context.py b/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_action_client/test_send_goal_as_context.py
new file mode 100644
index 0000000..eebae68
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_action_client/test_send_goal_as_context.py
@@ -0,0 +1,156 @@
+from threading import Event
+
+import pytest
+from action_msgs.msg import GoalStatus
+from action_msgs.srv import CancelGoal
+from node_helpers.futures import wait_for_future
+from node_helpers.robust_rpc import RobustRPCException
+from node_helpers.robust_rpc.errors import ActionCancellationRejected
+from node_helpers_msgs.action import RobustActionExample
+from node_helpers_msgs.action._robust_action_example import (
+    RobustActionExample_GetResult_Response,
+)
+
+from ..conftest import RobustCaller, RobustImplementer
+
+
+@pytest.mark.parametrize("wait_for_action", (True, False))
+def test_cancellation(
+    threaded_implementer: RobustImplementer, caller: RobustCaller, wait_for_action: bool
+) -> None:
+    """Basic test that the action is cancelled when the context ends"""
+
+    threaded_implementer.action_continue_event = Event()
+
+    with caller.action_client.send_goal_as_context(
+        RobustActionExample.Goal(), timeout=15
+    ) as handle:
+        assert handle.accepted
+        if wait_for_action:
+            # Wait for the action to be reached
+            threaded_implementer.action_reached_event.wait(timeout=15)
+
+        assert not threaded_implementer.action_continue_event.is_set()
+    assert handle.status == GoalStatus.STATUS_CANCELED
+    assert threaded_implementer.action_calls == 1
+
+    # Verify that calling get_result twice (after the context already did) is fine, and
+    # that it yields expected data.
+ response: RobustActionExample.Result = handle.get_result().result + assert response.data == RobustImplementer.EXPECTED_CANCELLED_DATA + + +def test_cancelling_uncancellable_action_raises_exception( + threaded_implementer: RobustImplementer, caller: RobustCaller +) -> None: + """If an action server rejects cancellation, an exception should be raised""" + threaded_implementer.action_continue_event = Event() + threaded_implementer.action_cancellable = False + + # Happy path (cancel = False and goal can finish) + with caller.action_client.send_goal_as_context( + RobustActionExample.Goal(), timeout=15, cancel=False + ) as handle: + threaded_implementer.action_continue_event.set() + + assert handle.status == GoalStatus.STATUS_SUCCEEDED + + # Unhappy path (cancel = True) + threaded_implementer.action_continue_event.clear() + with ( + pytest.raises(ActionCancellationRejected), + caller.action_client.send_goal_as_context( + RobustActionExample.Goal(), timeout=15, cancel=True + ), + ): + pass + + # This should let the called action finish during teardown + threaded_implementer.action_continue_event.set() + + +def test_no_cancellation_if_action_was_already_over( + threaded_implementer: RobustImplementer, caller: RobustCaller +) -> None: + """Make sure that calling get_result within the context works as expected.""" + with caller.action_client.send_goal_as_context( + RobustActionExample.Goal(), timeout=15 + ) as handle: + # Wait for the result of the action _before_ the end of the context + wait_for_future( + handle.get_result_async(), + type_=RobustActionExample_GetResult_Response, + timeout=15, + ) + assert handle.status == GoalStatus.STATUS_SUCCEEDED + + # Verify no change in status after the context exits + assert handle.status == GoalStatus.STATUS_SUCCEEDED + + +def test_no_cancellation_if_action_was_cancelled_in_context( + threaded_implementer: RobustImplementer, caller: RobustCaller +) -> None: + """Make sure that calling cancel_goal in the context also works as expected""" + threaded_implementer.action_continue_event = Event() + + with caller.action_client.send_goal_as_context( + RobustActionExample.Goal(), timeout=15 + ) as handle: + # Wait for the result of the action _before_ the end of the context + wait_for_future( + handle.cancel_goal_async(), + type_=CancelGoal.Response, + timeout=15, + ) + assert handle.status in ( + GoalStatus.STATUS_CANCELED, + GoalStatus.STATUS_CANCELING, + ) + + # Verify no change in status after the context exits + assert handle.status == GoalStatus.STATUS_CANCELED + + +def test_exceptions_in_context_action_still_raised( + threaded_implementer: RobustImplementer, caller: RobustCaller +) -> None: + """Make sure exceptions that happen remotely are still raised. 
+ + This basically makes sure that get_result() is being called after the context + exits + """ + threaded_implementer.raise_error = True + + with ( + pytest.raises( + RobustRPCException.like(RobustImplementer.OhMyAnError) + ) as ex_info, + caller.action_client.send_goal_as_context( + RobustActionExample.Goal(), timeout=15 + ) as handle, + ): + pass + + assert ex_info.value.message.status == GoalStatus.STATUS_ABORTED + assert handle.status == GoalStatus.STATUS_ABORTED + + +def test_exception_in_context_cause_cancellation( + threaded_implementer: RobustImplementer, caller: RobustCaller +) -> None: + """Test that if an exception is raised in the context, that the underlying action + is still cancelled + + This ensures that the function uses a `try: finally:` pattern + """ + threaded_implementer.action_continue_event = Event() + + with pytest.raises(OSError): + with caller.action_client.send_goal_as_context( + RobustActionExample.Goal(), timeout=3 + ) as handle: + # By not setting the action_continue_event, a timeout is guaranteed + raise OSError("Oh no, whaaat, some unexpected error?") + + assert handle.status == GoalStatus.STATUS_CANCELED diff --git a/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_mixin.py b/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_mixin.py new file mode 100644 index 0000000..5599aa1 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_mixin.py @@ -0,0 +1,213 @@ +from collections.abc import Callable +from typing import Any, cast +from unittest import mock + +import pytest +from action_msgs.msg import GoalStatus +from node_helpers.futures import run_action_with_timeout, wait_for_future +from node_helpers.robust_rpc import mixin +from node_helpers.robust_rpc.errors import InvalidRobustMessage, RobustRPCException +from node_helpers.robust_rpc.schema import validate_robust_message +from node_helpers.robust_rpc.typing import ( + RequestType, + ResponseType, + RobustActionMsg, + RobustServiceMsg, +) +from node_helpers_msgs.action import RobustActionExample +from node_helpers_msgs.action._robust_action_example import ( + RobustActionExample_GetResult_Response, +) +from node_helpers_msgs.srv import RobustServiceExample +from rclpy.action.server import ServerGoalHandle +from sensor_msgs.srv import SetCameraInfo +from visualization_msgs.srv import GetInteractiveMarkers + +from .conftest import RobustCaller, RobustImplementer + + +def _validate_successful_message( + msg: RobustServiceExample.Response | RobustActionExample.Result, +) -> None: + assert msg.data == RobustImplementer.EXPECTED_DATA + assert msg.error_name == "" + assert msg.error_description == "" + + +def _validate_exception( + ex_info: RobustRPCException, + expected_error_type: type[RobustRPCException], + expected_message_type: type[ResponseType], +) -> None: + """Validate that the exception passed along the expected values""" + # Validate error attributes + assert ex_info.__class__.__name__ == ex_info.error_name + assert type(ex_info) is RobustRPCException.like( + ex_info.error_name + ), "These must be exactly the same error object" + assert isinstance(ex_info, RobustRPCException) + assert isinstance(ex_info, expected_error_type) + + assert ex_info.error_name == RobustImplementer.OhMyAnError.__name__ + assert ex_info.error_description == RobustImplementer.ERROR_MSG + + # Validate the error message also includes the underlying exception information + assert RobustImplementer.ERROR_MSG in str( + ex_info + ), "The error message should be passed in the 
error message!" + assert RobustImplementer.OhMyAnError.__name__ in str( + ex_info + ), "The original error name should be passed in the error message" + + +def test_robust_service_exception( + caller: RobustCaller, implementer: RobustImplementer +) -> None: + """Test a service with a callback exception propagates to the client""" + implementer.raise_error = True + future = caller.service_client.call_async(request=RobustServiceExample.Request()) + + expected_error_type = RobustRPCException.like(RobustImplementer.OhMyAnError) + with pytest.raises(expected_error_type) as ex_info: + wait_for_future(future, RobustServiceExample.Response, timeout=30) + assert implementer.service_calls == 1 and implementer.action_calls == 0 + _validate_exception( + ex_info.value, expected_error_type, RobustServiceExample.Response + ) + + +def test_robust_service_success( + caller: RobustCaller, implementer: RobustImplementer +) -> None: + """Test a service without a callback exception works like normal""" + + implementer.raise_error = False + future = caller.service_client.call_async(request=RobustServiceExample.Request()) + response = wait_for_future(future, RobustServiceExample.Response, timeout=30) + + assert implementer.service_calls == 1 and implementer.action_calls == 0 + _validate_successful_message(response) + + +def test_robust_action_exception( + caller: RobustCaller, implementer: RobustImplementer +) -> None: + """Test an action with a callback exception propagates to the client""" + + implementer.raise_error = True + + expected_error_type = RobustRPCException.like(RobustImplementer.OhMyAnError) + with pytest.raises(expected_error_type) as ex_info: + run_action_with_timeout( + caller.action_client, + RobustActionExample.Goal(), + RobustActionExample_GetResult_Response, + ) + exception_object = ex_info.value + + assert implementer.service_calls == 0 and implementer.action_calls == 1 + assert ( + exception_object.message.status == GoalStatus.STATUS_ABORTED + ), "Actions that fail in errors should have the goal status set to aborted!" 
+ _validate_exception( + exception_object, expected_error_type, RobustActionExample_GetResult_Response + ) + + +def test_robust_action_success( + caller: RobustCaller, implementer: RobustImplementer +) -> None: + """Test an action without a callback exception works like normal""" + + implementer.raise_error = False + response = run_action_with_timeout( + caller.action_client, + RobustActionExample.Goal(), + RobustActionExample_GetResult_Response, + ) + + assert response.status is GoalStatus.STATUS_SUCCEEDED + assert implementer.service_calls == 0 and implementer.action_calls == 1 + _validate_successful_message(response.result) + + +@pytest.mark.parametrize( + ("message", "expect_error", "is_action"), + ( + # Test a valid service + (RobustServiceExample, False, False), + # Test invalid service(s) + (GetInteractiveMarkers, True, False), + (SetCameraInfo, True, False), + # Test a valid action + (RobustActionExample, False, True), + ), +) +def test_messages_are_checked_for_validity( + caller: RobustCaller, + message: type[RobustActionMsg] | type[RobustServiceMsg], + expect_error: bool, + is_action: bool, +) -> None: + """Test that creating a service or action checks the message for fields required""" + test_fns: list[Callable[[], Any]] = [] + + if is_action: + + def fake_action_callback(goal_handle: ServerGoalHandle) -> ResponseType: + return cast(ResponseType, None) + + action_msg = cast(type[RobustActionMsg], message) + test_fns.append(lambda: caller.create_robust_action_client(action_msg, "name")) + test_fns.append( + lambda: caller.create_robust_action_server( + action_msg, "name", fake_action_callback + ) + ) + else: + + def fake_service_callback(a: RequestType, b: ResponseType) -> ResponseType: + return cast(ResponseType, None) + + service_msg = cast(type[RobustServiceMsg], message) + test_fns.append(lambda: caller.create_robust_client(service_msg, "name")) + test_fns.append( + lambda: caller.create_robust_service( + service_msg, "name", fake_service_callback + ) + ) + + with mock.patch( + f"{mixin.__name__}.validate_robust_message", wraps=validate_robust_message + ) as validate: + for fn in test_fns: + if expect_error: + with pytest.raises(InvalidRobustMessage): + fn() + else: + fn() + assert validate.call_count == len( + test_fns + ), "The validate call should be called for all functions, always!" + + +def test_robust_rpc_error_passthrough( + caller: RobustCaller, implementer: RobustImplementer +) -> None: + """Tests that, if an action server raises a RobustRPCError, the error name and + description are passed through to the caller unchanged. In practice, this tends to + come up when one action calls another action which raises an exception. 
+ """ + + implementer.raise_robust_rpc_error = True + expected_error_type = RobustRPCException.like(implementer.ROBUST_RPC_ERROR_NAME) + with pytest.raises(expected_error_type) as ex_info: + run_action_with_timeout( + caller.action_client, + RobustActionExample.Goal(), + RobustActionExample_GetResult_Response, + ) + exception_object = ex_info.value + + assert exception_object.error_name == implementer.ROBUST_RPC_ERROR_NAME + assert exception_object.error_description == implementer.ERROR_MSG diff --git a/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_service_client.py b/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_service_client.py new file mode 100644 index 0000000..1706ebb --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/robust_rpc/test_service_client.py @@ -0,0 +1,37 @@ +from unittest import mock + +import pytest +from node_helpers.robust_rpc import ExecutorNotSetError +from node_helpers.testing import rclpy_context +from node_helpers_msgs.srv import RobustServiceExample + +from .conftest import RobustCaller, RobustImplementer + + +def test_waits_for_service( + caller: RobustCaller, implementer: RobustImplementer +) -> None: + """Test that actions are waiting for readiness on the first call""" + with mock.patch.object( + caller.service_client, + "wait_for_service", + wraps=caller.service_client.wait_for_service, + ) as wait_fn: + assert wait_fn.call_count == 0 + + caller.service_client.call(RobustServiceExample.Request()) + + assert wait_fn.call_count == 1 + + caller.service_client.call(RobustServiceExample.Request()) + + assert wait_fn.call_count == 1, "Waiting should only occur on the first call!" + + +def test_fails_without_executor() -> None: + """Test that attempting to call a service without the executor set fails""" + with rclpy_context() as context: + caller = RobustCaller(namespace="whatever", context=context) + + with pytest.raises(ExecutorNotSetError): + caller.service_client.call_async(RobustServiceExample.Request()) diff --git a/pkgs/node_helpers/node_helpers_test/integration/sensors/__init__.py b/pkgs/node_helpers/node_helpers_test/integration/sensors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/node_helpers/node_helpers_test/integration/sensors/conftest.py b/pkgs/node_helpers/node_helpers_test/integration/sensors/conftest.py new file mode 100644 index 0000000..63adb25 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/sensors/conftest.py @@ -0,0 +1,62 @@ +from collections.abc import Generator +from typing import Any + +import pytest +from node_helpers.nodes import HelpfulNode +from node_helpers.sensors import BaseSensorBuffer, BaseSensorPublisher +from node_helpers.sensors.base_publisher import DEFAULT_SENSORS_VIS_TOPIC, SENSOR_MSG +from node_helpers.testing import NodeForTesting, set_up_node +from node_helpers_msgs.msg import SensorExample +from rclpy.qos import qos_profile_services_default +from visualization_msgs.msg import Marker, MarkerArray + +SENSOR_FRAME_ID = "cool-frame" +SENSOR_TOPIC = "/example_sensor" + + +class ExamplePublisher( + BaseSensorPublisher[SensorExample, int, BaseSensorPublisher.Parameters] +): + def to_rviz_msg(self, msg: SENSOR_MSG) -> list[Marker]: + return [Marker(), Marker()] + + +class ExampleBuffer(BaseSensorBuffer[SensorExample]): + pass + + +class ExamplePubNode(HelpfulNode): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__("publisher", *args, **kwargs) + self.publisher = ExamplePublisher( + node=self, + 
msg_type=SensorExample, + parameters=ExamplePublisher.Parameters( + frame_id=SENSOR_FRAME_ID, sensor_topic=SENSOR_TOPIC + ), + sensor_qos=qos_profile_services_default, + vis_qos=qos_profile_services_default, + ) + + +class ExampleSubNode(NodeForTesting): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__("subscriber", *args, **kwargs) + self.subscriber = ExampleBuffer(self, SensorExample, SENSOR_TOPIC) + self.marker_queue = self.create_queue_subscription( + MarkerArray, DEFAULT_SENSORS_VIS_TOPIC + ) + + +@pytest.fixture() +def publisher_node() -> Generator[ExamplePubNode, None, None]: + yield from set_up_node( + node_class=ExamplePubNode, namespace="publish", node_name="publisher" + ) + + +@pytest.fixture() +def subscriber_node() -> Generator[ExampleSubNode, None, None]: + yield from set_up_node( + node_class=ExampleSubNode, namespace="subscribe", node_name="subscriber" + ) diff --git a/pkgs/node_helpers/node_helpers_test/integration/sensors/test_base_buffer.py b/pkgs/node_helpers/node_helpers_test/integration/sensors/test_base_buffer.py new file mode 100644 index 0000000..3ba17d3 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/sensors/test_base_buffer.py @@ -0,0 +1,142 @@ +from copy import deepcopy +from queue import Empty, Queue +from threading import Event + +import pytest +from builtin_interfaces.msg import Time +from node_helpers.testing import DynamicContextThread +from node_helpers_msgs.msg import SensorExample +from std_msgs.msg import Header + +from .conftest import SENSOR_FRAME_ID, ExamplePubNode, ExampleSubNode + +EXAMPLE_MSG = SensorExample( + header=Header(frame_id=SENSOR_FRAME_ID, stamp=Time(sec=1, nanosec=1)), value=5 +) + + +def test_on_value_change( + publisher_node: ExamplePubNode, subscriber_node: ExampleSubNode +) -> None: + """Test the ON_VALUE_CHANGE callback only calls when the value changes""" + msg = deepcopy(EXAMPLE_MSG) + + # Subscribe a queue to all changes in value + on_value_change: Queue[int] = Queue() + subscriber_node.subscriber.on_value_change.subscribe(on_value_change.put) + + # The first publish should always count as a value "change" + publisher_node.publisher.publish_sensor(msg) + received_msg = on_value_change.get(timeout=0.5) + assert received_msg == msg + + # Publish the same value, and expect no changes + publisher_node.publisher.publish_sensor(msg) + with pytest.raises(Empty): + on_value_change.get(timeout=0.5) + + # Change the value and validate the callback was called again + msg.value = 1337 + publisher_node.publisher.publish_sensor(msg) + received_msg = on_value_change.get(timeout=0.5) + assert received_msg == msg + assert on_value_change.qsize() == 0 + + +def test_on_receive( + publisher_node: ExamplePubNode, subscriber_node: ExampleSubNode +) -> None: + """Test the ON_RECEIVE callback calls for each single publish""" + msg = deepcopy(EXAMPLE_MSG) + + # Subscribe a queue to all publishes, regardless of whether it's changed + on_receive: Queue[int] = Queue() + subscriber_node.subscriber.on_receive.subscribe(on_receive.put) + + # Try publishing and receiving a few times + for _ in range(3): + publisher_node.publisher.publish_sensor(msg) + received_msg = on_receive.get(timeout=0.5) + assert received_msg == msg + + assert on_receive.qsize() == 0 + + +def test_get_timeouts( + publisher_node: ExamplePubNode, subscriber_node: ExampleSubNode +) -> None: + """Test the 'get' timeout functionality""" + # When no value has ever been received + with pytest.raises(TimeoutError): + 
subscriber_node.subscriber.get(timeout=0.1) + + # When there is some value, but no 'after' parameter is set, the message should be + # immediately returned + msg = deepcopy(EXAMPLE_MSG) + subscriber_node.subscriber._latest_reading = msg + assert subscriber_node.subscriber.get(timeout=0.0) == msg + + # If a message is requested with an "after" that is greater than the current message + # then it should block until timeout + with pytest.raises(TimeoutError): + subscriber_node.subscriber.get(after=Time(sec=99999999), timeout=0.1) + # Make sure there's no bugs when timeout=0.0 + with pytest.raises(TimeoutError): + subscriber_node.subscriber.get(after=Time(sec=99999999), timeout=0.0) + + # If a message is requested with an "after" that is < current message, it should + # return the last message immediately, regardless of timeout + assert ( + subscriber_node.subscriber.get(after=Time(sec=0, nanosec=1), timeout=0.0) == msg + ) + + +def test_get_after( + publisher_node: ExamplePubNode, subscriber_node: ExampleSubNode +) -> None: + """Test the 'get' functionality when requesting a timestamp greater than current""" + on_getting = Event() + on_gotten = Event() + after_stamp = Time(sec=3, nanosec=0) + msg = deepcopy(EXAMPLE_MSG) + after_msg = deepcopy(msg) + after_msg.header.stamp = after_stamp + + def wait_for_get() -> None: + on_getting.set() + assert subscriber_node.subscriber.get(after=after_stamp) == after_msg + on_gotten.set() + + with DynamicContextThread(target=wait_for_get): + # Wait for the system to be 'getting' + assert on_getting.wait(10) + # The "get" should subscribe to changes + while len(subscriber_node.subscriber.on_receive._subscribers) == 0: + pass + assert not on_gotten.is_set() + + # Now publish a few changes that are not past the requested number + for i in range(3): + msg.header.stamp.nanosec = i + publisher_node.publisher.publish_sensor(msg) + assert not on_gotten.wait(0.5) + + # Now "publish" a change with an 'after' that is past the requested number + publisher_node.publisher.publish_sensor(after_msg) + assert on_gotten.wait(0.5) + + +def test_reject_out_of_order_messages(subscriber_node: ExampleSubNode) -> None: + # Create two messages with different timestamps + newer_msg = deepcopy(EXAMPLE_MSG) + newer_msg.header.stamp = Time(sec=2, nanosec=0) + + older_msg = deepcopy(EXAMPLE_MSG) + older_msg.header.stamp = Time(sec=1, nanosec=0) + + # Publish a newer message and then an older one + subscriber_node.subscriber._on_receive(newer_msg) + subscriber_node.subscriber._on_receive(older_msg) + + # Check if the older message was rejected + assert subscriber_node.subscriber._latest_reading == newer_msg diff --git a/pkgs/node_helpers/node_helpers_test/integration/sensors/test_base_publisher.py b/pkgs/node_helpers/node_helpers_test/integration/sensors/test_base_publisher.py new file mode 100644 index 0000000..45c35f8 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/sensors/test_base_publisher.py @@ -0,0 +1,123 @@ +import math +from copy import deepcopy +from unittest import mock + +import pytest +from builtin_interfaces.msg import Time +from node_helpers_msgs.msg import SensorExample +from std_msgs.msg import Header +from visualization_msgs.msg import Marker + +from .conftest import ( + SENSOR_FRAME_ID, + SENSOR_TOPIC, + ExamplePublisher, + ExamplePubNode, + ExampleSubNode, +) + + +def test_publish_sensor( + publisher_node: ExamplePubNode, subscriber_node: ExampleSubNode +) -> None: + """A full end-to-end test of publishing and receiving values. 
+ Ensures: + - Sensor value is published + - A marker array is generated, published, and assigned IDs and a namespace + """ + msg = SensorExample( + header=Header( + frame_id=publisher_node.publisher.frame_id, stamp=Time(sec=1, nanosec=1) + ), + value=5, + ) + publisher_node.publisher.publish_sensor(msg) + + # Verify the received message + sensor_msg_received = subscriber_node.subscriber.get(timeout=5) + assert sensor_msg_received == msg + + # Verify the visualization messages got generated and published as expected + markers: list[Marker] = subscriber_node.marker_queue.get(timeout=5).markers + assert len(markers) == 3 + assert markers[0].action == Marker.DELETEALL + for index, marker in enumerate(markers[1:]): + assert marker.header == msg.header, "Markers should share the sensor header" + assert marker.id == index + 1 + + +def test_publish_sensor_sanitization(publisher_node: ExamplePubNode) -> None: + """Publishing should check the header for validity""" + working_msg = SensorExample( + header=Header( + frame_id=publisher_node.publisher.frame_id, stamp=Time(sec=1, nanosec=1) + ), + value=5, + ) + + # This should work + publisher_node.publisher.publish_sensor(working_msg) + + # Test publishing a header that is not relevant to the sensor + bad_frame_id = deepcopy(working_msg) + bad_frame_id.header.frame_id = "not the assigned frame id" + with pytest.raises(ValueError): + publisher_node.publisher.publish_sensor(bad_frame_id) + + # Test publishing an unset timestamp + bad_stamp = deepcopy(working_msg) + bad_stamp.header.stamp = Time() + with pytest.raises(ValueError): + publisher_node.publisher.publish_sensor(bad_stamp) + + # Validate that markers with pre-assigned headers aren't allowed + with mock.patch.object(publisher_node.publisher, "to_rviz_msg") as to_rviz_msg: + to_rviz_msg.return_value = [Marker(header=Header(frame_id="some preset frame"))] + + with pytest.raises(ValueError): + publisher_node.publisher.publish_sensor(working_msg) + + +def test_publish_value(publisher_node: ExamplePubNode) -> None: + """Test that publishing a value uses a current timestamp with the requisite + frame_id""" + + # Test publishing a bare integer value, without constructing a full message + expected = 1231 + with mock.patch.object( + publisher_node.publisher._sensor_publisher, "publish" + ) as ros_publish: + publisher_node.publisher.publish_value(expected) + + # Verify the frame ID and a timestamp was assigned + assert ros_publish.call_count == 1 + called_with: SensorExample = ros_publish.call_args[0][0] + assert called_with.header.frame_id == publisher_node.publisher.frame_id + assert called_with.header.stamp.sec != 0 + + +def test_throttling_functionality_has_ttl_caches_assigned( + publisher_node: ExamplePubNode, subscriber_node: ExampleSubNode +) -> None: + """Verify that throttled visualization / sensor publishers have ttl caches applied""" + # Create a throttled publisher + throttled_publisher = ExamplePublisher( + node=publisher_node, + msg_type=SensorExample, + parameters=ExamplePublisher.Parameters( + vis_publishing_max_hz=12.0, + sensor_publishing_max_hz=6.0, + frame_id=SENSOR_FRAME_ID, + sensor_topic=SENSOR_TOPIC, + ), + ) + + throttled_vis_publish_fn = throttled_publisher._publish_rviz_markers + normal_vis_publish_fn = publisher_node.publisher._publish_rviz_markers + assert math.isclose(throttled_vis_publish_fn.cache.ttl_seconds, 1 / 12) # type: ignore + assert not hasattr(normal_vis_publish_fn, "cache") + + throttled_sensor_publish_fn = throttled_publisher.publish_sensor + normal_sensor_publish_fn = publisher_node.publisher.publish_sensor + 
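# Methods wrapped by the ttl_cached throttle expose their TTL via a .cache attribute + 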
assert math.isclose(throttled_sensor_publish_fn.cache.ttl_seconds, 1 / 6) # type: ignore + assert not hasattr(normal_sensor_publish_fn, "cache") diff --git a/pkgs/node_helpers/node_helpers_test/integration/sensors/test_binary_signal.py b/pkgs/node_helpers/node_helpers_test/integration/sensors/test_binary_signal.py new file mode 100644 index 0000000..d8318c1 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/sensors/test_binary_signal.py @@ -0,0 +1,85 @@ +from collections.abc import Generator +from typing import Any + +import pytest +from builtin_interfaces.msg import Time +from geometry_msgs.msg import Vector3 +from node_helpers.sensors.binary_signal import BinarySensorFromRangeFinder +from node_helpers.sensors.rangefinder import RangefinderBuffer +from node_helpers.testing import NodeForTesting, set_up_node +from node_helpers_msgs.msg import BinaryReading, RangefinderReading +from std_msgs.msg import Header + + +class ClientNode(NodeForTesting): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs, node_name="node") + + self.binary_queue = self.create_queue_subscription( + BinaryReading, "binary_topic" + ) + self.rangefinder_queue = self.create_queue_subscription( + RangefinderReading, "rangefinder_topic" + ) + self.rangefinder_publisher = self.create_publisher( + RangefinderReading, "rangefinder_topic", 10 + ) + + +@pytest.fixture() +def node() -> Generator[ClientNode, None, None]: + yield from set_up_node( + node_class=ClientNode, + namespace="test_namespace", + node_name="test_node", + multi_threaded=True, + ) + + +def test_binary_sensor_from_rangefinder(node: ClientNode) -> None: + """Test that the BinarySensorFromRangeFinder updates the binary state correctly""" + rangefinder_buffer = RangefinderBuffer( + node=node, + sensor_topic="rangefinder_topic", + callback_group=None, + ) + parameters = BinarySensorFromRangeFinder.Parameters( + sensor_topic="binary_topic", + vis_topic="visualization_topic", + frame_id="test_frame", + threshold=1.0, + inverted=False, + ) + binary_sensor = BinarySensorFromRangeFinder( + node=node, + rangefinder=rangefinder_buffer, + parameters=parameters, + ) + + # Simulate a rangefinder reading below the threshold + reading_below_threshold = RangefinderReading( + header=Header(frame_id="test_frame", stamp=Time(sec=3)), + value=Vector3(x=0.0, y=0.0, z=0.5), + ) + + assert node.binary_queue.qsize() == 0 + node.rangefinder_publisher.publish(reading_below_threshold) + + # Validate that the binary sensor is triggered + assert node.binary_queue.get(timeout=5.0).value is True + + # Simulate a rangefinder reading above the threshold + reading_above_threshold = RangefinderReading( + header=Header(frame_id="test_frame", stamp=Time(sec=3)), + value=Vector3(x=0.0, y=0.0, z=1.5), + ) + node.rangefinder_publisher.publish(reading_above_threshold) + + # Validate that the binary sensor is not triggered + assert node.binary_queue.get(timeout=5.0).value is False + + # Simulate a rangefinder reading below the threshold again + node.rangefinder_publisher.publish(reading_below_threshold) + + # Validate that the binary sensor is triggered again + assert node.binary_queue.get(timeout=5.0).value is True diff --git a/pkgs/node_helpers/node_helpers_test/integration/tf/test_constant_broadcaster.py b/pkgs/node_helpers/node_helpers_test/integration/tf/test_constant_broadcaster.py new file mode 100644 index 0000000..5952166 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/tf/test_constant_broadcaster.py @@ -0,0 +1,59 @@ +from 
collections.abc import Generator +from time import sleep +from typing import Any + +import pytest +from geometry_msgs.msg import Transform, TransformStamped +from node_helpers.testing import NodeForTesting, set_up_node +from node_helpers.tf import ConstantStaticTransformBroadcaster +from node_helpers.timing import TestingTimeout as Timeout +from std_msgs.msg import Header +from tf2_msgs.msg import TFMessage + + +class ClientNode(NodeForTesting): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs, node_name="node") + + self.tf_queue = self.create_queue_subscription(TFMessage, "/tf_static") + + +@pytest.fixture() +def node() -> Generator[ClientNode, None, None]: + yield from set_up_node( + node_class=ClientNode, + namespace="ayylaymayo", + node_name="nodeynode", + multi_threaded=True, + ) + + +def test_basic_operation(node: ClientNode) -> None: + """Test that the broadcaster broadcasts regularly and allows updating the TF""" + tf_1 = TransformStamped(transform=Transform(), header=Header(frame_id="first")) + tf_2 = TransformStamped(transform=Transform(), header=Header(frame_id="second")) + + broadcaster = ConstantStaticTransformBroadcaster( + node=node, publish_seconds=0.1, initial_transform=tf_1 + ) + + tf_1_expected = TFMessage(transforms=[tf_1]) + tf_2_expected = TFMessage(transforms=[tf_2]) + + # Validate that the publisher constantly publishes + assert node.tf_queue.get(timeout=5.0) == tf_1_expected + assert node.tf_queue.get(timeout=5.0) == tf_1_expected + assert node.tf_queue.get(timeout=5.0) == tf_1_expected + + # Set a new transform + broadcaster.set_transform(tf_2) + + # Wait until it is broadcasting the new transform + timeout = Timeout(seconds=25) + while timeout and node.tf_queue.get(timeout=5.0) == tf_1_expected: + sleep(0.01) + + # The new transform should now be the one being published + assert node.tf_queue.get(timeout=5.0) == tf_2_expected + assert node.tf_queue.get(timeout=5.0) == tf_2_expected + assert node.tf_queue.get(timeout=5.0) == tf_2_expected diff --git a/pkgs/node_helpers/node_helpers_test/integration/timing/test_singleshots_mixin.py b/pkgs/node_helpers/node_helpers_test/integration/timing/test_singleshots_mixin.py new file mode 100644 index 0000000..6208327 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/timing/test_singleshots_mixin.py @@ -0,0 +1,146 @@ +import gc +from collections.abc import Generator +from queue import Empty, Queue +from time import sleep +from typing import Any, Literal + +import pytest +from node_helpers.testing import set_up_node +from node_helpers.timing import SingleShotMixin, Timeout +from rclpy.exceptions import InvalidHandle +from rclpy.node import Node + + +class SingleShotNode(Node, SingleShotMixin): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__("single_shot_node", *args, **kwargs) + + +@pytest.fixture(params=[False, True]) +def single_shot_node( + request: pytest.FixtureRequest, +) -> Generator[SingleShotNode, None, None]: + """Yields two versions of the single shot node: single-threaded, then + multi-threaded""" + yield from set_up_node( + node_class=SingleShotNode, + namespace="lol_idk", + node_name="single_shot_node", + parameter_overrides=[], + multi_threaded=request.param, + ) + + +def test_single_shots_run_once(single_shot_node: SingleShotNode) -> None: + """Basic test that a newly created single shot yields only one response""" + results: Queue[Literal[1]] = Queue() + single_shot_node.create_single_shot_timer( + timer_period_sec=0, + callback=lambda: 
results.put(1), + ) + + assert results.get(timeout=0.2) == 1 + with pytest.raises(Empty): + results.get(timeout=0.2) + + +def test_single_shots_return_in_order(single_shot_node: SingleShotNode) -> None: + """This verifies single shots are run 'in order' of their timer_period_sec, by + creating multiple single shots with incrementally increasing periods, then + verifying they returned in the correct order""" + + results: Queue[int] = Queue() + n_single_shots = 10 + time_offset = 0.3 + + # Create several timers in reverse order, to make sure the timer_period_sec is + # the only thing affecting order + for i in range(n_single_shots - 1, -1, -1): + single_shot_node.create_single_shot_timer( + timer_period_sec=i * time_offset, + callback=lambda i=i: results.put(i), # type: ignore + ) + + # Now wait for all the single shots and validate they returned in order + for i in range(n_single_shots): + assert results.get() == i + + # There should not be anything else + with pytest.raises(Empty): + results.get(timeout=0.1) + + +def test_timers_cancelled_and_dereferenced(single_shot_node: SingleShotNode) -> None: + """Test that creating many single shots works robustly, and doesn't leave us with + more references than we began with""" + + results: Queue[None] = Queue(maxsize=1) + n_single_shots = 1000 + + timers = tuple( + single_shot_node.create_single_shot_timer( + timer_period_sec=0, + callback=lambda: results.put(None, timeout=30), + ) + for _ in range(n_single_shots) + ) + + # Get the number of referrers to the timer objects before allowing their callbacks + # to complete + referrers_count_before = len(gc.get_referrers(*timers)) + task_count_before = len(single_shot_node.executor._tasks) + assert task_count_before > 0, "Some tasks should be added for timers to work" + + # Now wait for all the single shots, allowing the callbacks to finish (because the + # queue has a maxsize=1, it will block the timer callbacks until get() is called) + for _ in range(n_single_shots): + results.get(timeout=30) + + # Validate the timer references have dropped by n_single_shots and no more tasks + # exist in the executor. This can take a moment when running with a multithreaded + # executor, so we run this with a timeout. + timeout = Timeout(seconds=10) + + def ros_has_dropped_references() -> bool: + current_n_referrers = len(gc.get_referrers(*timers)) + return referrers_count_before - current_n_referrers >= n_single_shots - 1 + + while not ros_has_dropped_references() and timeout: + # Garbage collect so that `__del__` methods in the timer get triggered + gc.collect() + # Wake the executor to let it know that something has changed + single_shot_node.executor.wake() + sleep(0.1) + + task_count_after = len(single_shot_node.executor._tasks) + assert results.qsize() == 0 + assert task_count_after <= 1, "Only the 'wake()' task (or no tasks) should be left!"
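+ # (executor._tasks is private rclpy state, inspected here purely for testing)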
+ + # Validate all timers were destroyed + for timer in timers: + with pytest.raises(InvalidHandle): + timer.is_ready() + + +def test_timers_creation_allowed_before_executor() -> None: + """Test that you can safely call create_single_shot_timer in a node __init__, and + it will fire as soon as the node is spun up with an executor""" + + requests: Queue[None] = Queue() + + class SingleShotOnInitNode(Node, SingleShotMixin): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__("single_shot_node", *args, **kwargs) + self.create_single_shot_timer( + timer_period_sec=0, callback=lambda: requests.put(None) + ) + + # The following will run the node __init__ then assign it an executor. This should + # work just fine. + node_generator = set_up_node(SingleShotOnInitNode, "node_namespace", "node_name") + next(node_generator) + + # Ensure a request is received + requests.get(timeout=15) + + # Shut down the node + tuple(node_generator) diff --git a/pkgs/node_helpers/node_helpers_test/integration/timing/test_ttl_caches.py b/pkgs/node_helpers/node_helpers_test/integration/timing/test_ttl_caches.py new file mode 100644 index 0000000..0e387c1 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/timing/test_ttl_caches.py @@ -0,0 +1,114 @@ +from typing import cast +from unittest.mock import Mock + +from node_helpers.timing import ttl_cached + + +def test_basic_use_on_bound_methods() -> None: + ttl = 0.25 + + class SomeClass: + return_value = 3.0 + + @ttl_cached(seconds=ttl) + def some_fn(self) -> float: + return self.return_value + + some_instance = SomeClass() + timer = Mock() + some_instance.some_fn.cache._get_time = timer + + timer.return_value = 0 + assert some_instance.some_fn() == 3.0 + assert some_instance.some_fn() == 3.0 + some_instance.return_value = 4.0 + assert some_instance.some_fn() == 3.0 + assert some_instance.some_fn() == 3.0 + timer.return_value = ttl + 0.1 + assert some_instance.some_fn() == 4.0 + + # Also try busy looping continuously while waiting for a value to change + some_instance.return_value = 5.0 + assert some_instance.some_fn() == 4.0 + while some_instance.some_fn() == 4.0: + timer.return_value += 0.1 + assert some_instance.some_fn() == 5.0 + + +def test_basic_use_on_unbound_methods() -> None: + ttl = 0.25 + + value_holder = Mock() + value_holder.return_value = 3.0 + + @ttl_cached(seconds=ttl) + def some_fn() -> float: + return cast(float, value_holder()) + + timer = Mock() + some_fn.cache._get_time = timer + timer.return_value = 0 + + assert some_fn() == 3.0 + assert some_fn() == 3.0 + value_holder.return_value = 4.0 + assert some_fn() == 3.0 + assert some_fn() == 3.0 + timer.return_value = ttl + 0.1 + assert some_fn() == 4.0 + + # Also try busy looping continuously while waiting for a value to change + value_holder.return_value = 5.0 + assert some_fn() == 4.0 + while some_fn() == 4.0: + timer.return_value += 0.1 + assert some_fn() == 5.0 + + +def test_not_called_even_when_returns_none() -> None: + """Test that the TTL functionality respects the None return value""" + ttl = 0.25 + underlying_mock = Mock() + + class SomeClass: + @ttl_cached(seconds=ttl) + def some_fn(self) -> None: + underlying_mock() + + some_instance = SomeClass() + + timer = Mock() + some_instance.some_fn.cache._get_time = timer + timer.return_value = 0 + + assert underlying_mock.call_count == 0 + some_instance.some_fn() + assert underlying_mock.call_count == 1 + some_instance.some_fn() + assert underlying_mock.call_count == 1 + timer.return_value = ttl + 0.1 + 
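# The mocked clock is now past the TTL, so the next call should refresh the cached value + 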
some_instance.some_fn() + assert underlying_mock.call_count == 2 + some_instance.some_fn() + assert underlying_mock.call_count == 2 + + +def test_ttl_editing() -> None: + """Test editing the ttl property on a method""" + + class SomeClass: + return_value = 3.0 + + @ttl_cached(seconds=999.0) + def some_fn(self) -> float: + return self.return_value + + some_instance = SomeClass() + + assert some_instance.some_fn() == 3.0 + assert some_instance.some_fn() == 3.0 + some_instance.return_value = 4.0 + assert some_instance.some_fn() == 3.0 + assert some_instance.some_fn() == 3.0 + some_instance.some_fn.cache.ttl_seconds = 0.0 + assert some_instance.some_fn() == 4.0 diff --git a/pkgs/node_helpers/node_helpers_test/integration/topics/test_latching_publisher.py b/pkgs/node_helpers/node_helpers_test/integration/topics/test_latching_publisher.py new file mode 100644 index 0000000..00dbf0b --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/integration/topics/test_latching_publisher.py @@ -0,0 +1,64 @@ +from collections.abc import Generator +from time import sleep +from typing import Any + +import pytest +from node_helpers.testing import NodeForTesting, set_up_node +from node_helpers.timing import TestingTimeout as Timeout +from node_helpers.topics.latching_publisher import LatchingPublisher +from std_msgs.msg import String + + +class ClientNode(NodeForTesting): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs, node_name="node") + + self.msg_queue = self.create_queue_subscription(String, "test_topic") + + +@pytest.fixture() +def node() -> Generator[ClientNode, None, None]: + yield from set_up_node( + node_class=ClientNode, + namespace="test_namespace", + node_name="test_node", + multi_threaded=True, + ) + + +def test_publish_message(node: ClientNode) -> None: + """Test that the LatchingPublisher publishes messages correctly""" + publisher = LatchingPublisher(node, String, "test_topic", republish_delay=0.1) + test_message = String(data="Hello, world!") + + publisher(test_message) + + # Validate that the publisher publishes the message + assert node.msg_queue.get(timeout=5.0).data == "Hello, world!" + assert node.msg_queue.get(timeout=5.0).data == "Hello, world!" + assert node.msg_queue.get(timeout=5.0).data == "Hello, world!" 
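+ # Repeated gets confirm the message keeps being republished, rather than sent once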
+ + +def test_change_message(node: ClientNode) -> None: + """Test that the LatchingPublisher updates the message correctly""" + publisher = LatchingPublisher(node, String, "test_topic", republish_delay=0.1) + initial_message = String(data="Initial message") + updated_message = String(data="Updated message") + + publisher(initial_message) + + # Validate that the initial message is published + assert node.msg_queue.get(timeout=5.0).data == "Initial message" + + # Update the message + publisher(updated_message) + + # Wait for the republish delay and validate the updated message + timeout = Timeout(seconds=25) + while timeout and node.msg_queue.get(timeout=5.0).data == "Initial message": + sleep(0.01) + + # The updated message should now be being published + assert node.msg_queue.get(timeout=5.0).data == "Updated message" + assert node.msg_queue.get(timeout=5.0).data == "Updated message" + assert node.msg_queue.get(timeout=5.0).data == "Updated message" diff --git a/pkgs/node_helpers/node_helpers_test/resources/__init__.py b/pkgs/node_helpers/node_helpers_test/resources/__init__.py new file mode 100644 index 0000000..91bc303 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/resources/__init__.py @@ -0,0 +1,7 @@ +from node_helpers.testing import resource_path + +_RESOURCE_DIR = resource_path("node_helpers_test/resources") +_URDFS_DIR = resource_path(_RESOURCE_DIR, "urdfs") + +# Path resources +GENERIC_URDF = resource_path(_URDFS_DIR, "robot.urdf") diff --git a/pkgs/node_helpers/node_helpers_test/resources/urdfs/robot.urdf b/pkgs/node_helpers/node_helpers_test/resources/urdfs/robot.urdf new file mode 100644 index 0000000..857ab13 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/resources/urdfs/robot.urdf @@ -0,0 +1,114 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/pkgs/node_helpers/node_helpers_test/unit/actions/test_server.py b/pkgs/node_helpers/node_helpers_test/unit/actions/test_server.py new file mode 100644 index 0000000..375a825 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/actions/test_server.py @@ -0,0 +1,164 @@ +from collections.abc import Generator +from dataclasses import dataclass +from unittest.mock import Mock + +import pytest +from node_helpers.actions.server import ( + ActionTimeoutError, + ActionWorker, + NoResultSetError, +) +from rclpy.action.server import ServerGoalHandle + + +class ExampleGoal: + pass + + +class ExampleFeedback: + pass + + +@dataclass +class ExampleResult: + value: str + error_name: str = "" + error_description: str = "" + + +ExampleWorkerType = ActionWorker[ExampleGoal, ExampleFeedback, ExampleResult] + + +class ExampleException(Exception): + pass + + +class ExampleWorker(ExampleWorkerType): + """A basic action that produces feedback, responds to errors and cancellations, and + provides a result + """ + + FEEDBACK_COUNT = 3 + + def __init__(self, goal_handle: ServerGoalHandle): + super().__init__(goal_handle) + + self.canceled = False # This says a lot about our society + + def run(self) -> Generator[ExampleFeedback | None, None, None]: + for _ in range(self.FEEDBACK_COUNT): + yield ExampleFeedback() + self.result = ExampleResult("complete") + + def on_cancel(self) -> ExampleResult: + self.canceled = True + return ExampleResult("cancel") + + +def test_action_worker_normal_operation() -> None: + """Tests that a successful worker has its feedback and results published""" + + goal_handle = 
Mock() + goal_handle.is_cancel_requested = False + + worker = ExampleWorker(goal_handle) + result = worker.execute_callback(None) + + _assert_complete(worker, goal_handle, result) + assert goal_handle.publish_feedback.call_count == ExampleWorker.FEEDBACK_COUNT + + +def test_action_worker_canceled() -> None: + """Tests that a worker can be canceled, and that the worker's cancel callback is + called + """ + + goal_handle = Mock() + goal_handle.is_cancel_requested = True + + worker = ExampleWorker(goal_handle) + result = worker.execute_callback(None) + + _assert_canceled(worker, goal_handle, result) + + +class NoResultWorker(ExampleWorker): + """A worker implementation that forgot to set the result value after execution""" + + def run(self) -> Generator[ExampleFeedback | None, None, None]: + for _ in range(self.FEEDBACK_COUNT): + yield ExampleFeedback() + + +def test_action_worker_no_result() -> None: + """Tests that the worker errors out if subclasses fail to set the result attribute + after successful execution + """ + goal_handle = Mock() + goal_handle.is_cancel_requested = False + + worker = NoResultWorker(goal_handle) + + with pytest.raises(NoResultSetError): + worker.execute_callback(None) + + +class ForeverWorker(ExampleWorker): + """A worker implementation that will run forever unless stopped by a timeout""" + + def run(self) -> Generator[ExampleFeedback | None, None, None]: + while True: + yield ExampleFeedback() + + +def test_action_worker_timeout() -> None: + """Tests that long-running workers can be interrupted by a timeout""" + + goal_handle = Mock() + goal_handle.is_cancel_requested = False + + worker = ForeverWorker(goal_handle) + + with pytest.raises(ActionTimeoutError): + worker.execute_callback(1.0) + + +class NoneYieldingWorker(ExampleWorker): + """A worker that yields None instead of a feedback object""" + + def run(self) -> Generator[ExampleFeedback | None, None, None]: + for _ in range(self.FEEDBACK_COUNT): + yield None + self.result = ExampleResult("complete") + + +def test_action_worker_none_yielding() -> None: + """Tests that workers are allowed to yield None, and no feedback is provided when + they do + """ + goal_handle = Mock() + goal_handle.is_cancel_requested = False + + worker = NoneYieldingWorker(goal_handle) + result = worker.execute_callback(None) + + _assert_complete(worker, goal_handle, result) + goal_handle.publish_feedback.assert_not_called() + + +def _assert_complete( + worker: ExampleWorkerType, goal_handle: Mock, result: ExampleResult +) -> None: + assert isinstance(result, ExampleResult) + assert result.value == "complete" + goal_handle.succeed.assert_called_once() + assert worker.done + + +def _assert_canceled( + worker: ExampleWorkerType, goal_handle: Mock, result: ExampleResult +) -> None: + assert isinstance(result, ExampleResult) + assert result.value == "cancel" + goal_handle.canceled.assert_called_once() + assert worker.done diff --git a/pkgs/node_helpers/node_helpers_test/unit/futures/test_futures.py b/pkgs/node_helpers/node_helpers_test/unit/futures/test_futures.py new file mode 100644 index 0000000..abaeff1 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/futures/test_futures.py @@ -0,0 +1,67 @@ +from threading import Event + +import pytest +from node_helpers import futures +from node_helpers.testing import DynamicContextThread +from rclpy import Future + + +def test_wait_for_futures_waits_on_all_even_after_exception() -> None: + """Ensure that wait_for_futures will always wait for _all_ futures to complete + before raising the 
first exception that occurred""" + + all_futures = [Future(), Future(), Future(), Future()] + finished = Event() + + def wait_on_futures() -> None: + try: + futures.wait_for_futures(all_futures, object) + except ValueError: + finished.set() + + with DynamicContextThread(wait_on_futures): + assert not finished.is_set() + + # Complete the first two futures successfully + all_futures[0].set_result(None) + all_futures[1].set_result(None) + + # Have the third end in an exception + all_futures[2].set_exception(ValueError) + + assert not finished.wait(0.1) + + # Finally, let the function continue + all_futures[3].set_result(None) + + assert finished.wait(0.1) + + +def test_yield_for_future_on_unfinished_future() -> None: + """Test yield_for_future basic usage""" + future = Future() + expected_result = 3 + + generator = futures.yield_for_future(future, object, yield_interval=0) + for _ in range(3): + assert next(generator) is None + + # Now set the future, and ensure it returns the expected result + future.set_result(expected_result) + with pytest.raises(StopIteration) as error: + next(generator) + + assert error.value.value == expected_result + + +def test_yield_for_future_on_finished_future() -> None: + future = Future() + expected_value = 20 + future.set_result(expected_value) + + generator = futures.yield_for_future(future, object, yield_interval=0) + + # This should raise a StopIteration immediately, since the future is done + with pytest.raises(StopIteration) as error: + next(generator) + assert error.value.value == expected_value diff --git a/pkgs/node_helpers/node_helpers_test/unit/launching/test_files.py b/pkgs/node_helpers/node_helpers_test/unit/launching/test_files.py new file mode 100644 index 0000000..00449e0 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/launching/test_files.py @@ -0,0 +1,20 @@ +from tempfile import NamedTemporaryFile, TemporaryDirectory + +import pytest +from node_helpers.launching import required_directory, required_file + + +def test_required_file() -> None: + with pytest.raises(FileNotFoundError): + required_file("/i_dont_exist") + + with NamedTemporaryFile() as file_: + required_file(file_.name) + + +def test_required_directory() -> None: + with pytest.raises(FileNotFoundError): + required_directory("/seriously_this_dir_isnt_real") + + with TemporaryDirectory() as dir_: + required_directory(dir_) diff --git a/pkgs/node_helpers/node_helpers_test/unit/launching/test_swappable_nodes.py b/pkgs/node_helpers/node_helpers_test/unit/launching/test_swappable_nodes.py new file mode 100644 index 0000000..5d35e8e --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/launching/test_swappable_nodes.py @@ -0,0 +1,76 @@ +from typing import cast + +import pytest +from launch_ros.actions import Node +from node_helpers.launching import ( + InvalidSwapConfiguration, + SwapConfiguration, + SwappableNode, + apply_node_swaps, +) + +namespace_a = "namespace_a" +namespace_b = "namespace_b" +NODE_A = SwappableNode(namespace=namespace_a, name="node_a", executable="") +NODE_A_MOCK = SwappableNode(namespace=namespace_a, name="node_a_mock", executable="") +NODE_B = SwappableNode(namespace=namespace_b, name="node_b", executable="") +NODE_B_MOCK = SwappableNode(namespace=namespace_b, name="node_b_mock", executable="") +NORMAL_NODE_NAMESPACE_A = Node(namespace=namespace_a, name="cool_node", executable="") + + +@pytest.mark.parametrize( + ("config", "input_nodes", "expected_output"), + ( + # Test a mixed (some mocked, some not) example + ( + { + namespace_a: SwapConfiguration( 
mock="node_a_mock", real="node_a", enable_mock=False + ), + namespace_b: SwapConfiguration( + mock="node_b_mock", real="node_b", enable_mock=True + ), + }, + [NODE_A, NODE_A_MOCK, NODE_B, NODE_B_MOCK, NORMAL_NODE_NAMESPACE_A], + [NODE_A, NODE_B_MOCK, NORMAL_NODE_NAMESPACE_A], + ), + # Test no swaps or swap configuration + ({}, [NORMAL_NODE_NAMESPACE_A], [NORMAL_NODE_NAMESPACE_A]), + # Test no swap configuration but there are swaps in there + ({}, [NODE_A, NORMAL_NODE_NAMESPACE_A], InvalidSwapConfiguration), + # Test when a swap is enabled but its pair node is missing + ( + { + namespace_a: SwapConfiguration( + mock="node_a_mock", real="node_a", enable_mock=False + ) + }, + [NODE_A, NORMAL_NODE_NAMESPACE_A], + InvalidSwapConfiguration, + ), + # Test that invalid SwapConfigurations aren't allowed (a configuration where + # the mock name and real name are the same) + ( + { + namespace_a: SwapConfiguration( + mock="node_a_mock", real="node_a_mock", enable_mock=False + ) + }, + [NODE_A, NODE_A_MOCK], + InvalidSwapConfiguration, + ), + ), +) +def test_basic_usage( + config: dict[str, SwapConfiguration], + input_nodes: list[SwappableNode | Node], + expected_output: type[InvalidSwapConfiguration] | list[SwappableNode], +) -> None: + if expected_output is InvalidSwapConfiguration: + with pytest.raises(cast(type[Exception], expected_output)): + apply_node_swaps(configuration=config, launch_description=input_nodes) + else: + filtered = apply_node_swaps( + configuration=config, launch_description=input_nodes + ) + assert filtered == expected_output diff --git a/pkgs/node_helpers/node_helpers_test/unit/launching/test_urdf.py b/pkgs/node_helpers/node_helpers_test/unit/launching/test_urdf.py new file mode 100644 index 0000000..f0cd258 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/launching/test_urdf.py @@ -0,0 +1,112 @@ +import pytest +from node_helpers.launching import urdf +from node_helpers.launching.urdf import NAMESPACE_FMT +from node_helpers_test.resources import GENERIC_URDF + +EXPECTED_JOINT_NAMES = ["shuttle1-joint", "clamp1-joint"] +EXPECTED_LINK_NAMES = ["base_link", "shuttle1", "clamp1"] + + +def test_fix_urdf_paths_makes_path_replacements() -> None: + package_name = "node_helpers" + expected_pattern = f"{package_name}/{GENERIC_URDF.parent}/" + + modified_urdf = urdf.fix_urdf_paths( + package=package_name, relative_urdf_path=GENERIC_URDF + ) + original_urdf = GENERIC_URDF.read_text() + assert modified_urdf != original_urdf + + n_expected_changes = original_urdf.count("package://") + assert n_expected_changes == modified_urdf.count(expected_pattern) + assert original_urdf.count(expected_pattern) == 0 + + +def test_assert_joint_names_exist() -> None: + urdf_text = GENERIC_URDF.read_text() + + # Raises error when no names provided + with pytest.raises(ValueError): + urdf.assert_joint_names_exist([urdf_text], []) + + # Raises an error when invalid names provided + with pytest.raises(ValueError): + urdf.assert_joint_names_exist( + [urdf_text], [*EXPECTED_JOINT_NAMES, "nonexistent-joint"] + ) + + # No assertions raised when all joint names are valid + urdf.assert_joint_names_exist([urdf_text], EXPECTED_JOINT_NAMES) + + +def test_assert_attributes_exist_fails_on_duplicate_attributes_across_urdfs() -> None: + """The basic contract with multi-urdf loading is that each urdf has unique frame + or joint names; at least, the ones that are being asserted. 
This test validates + that check is being done as expected.""" + + urdf_text = GENERIC_URDF.read_text() + urdf_text_namespaced = urdf.prepend_namespace(GENERIC_URDF.read_text(), "cool") + + _expected_links_namespaced = [ + NAMESPACE_FMT.format(namespace="cool", name=n) for n in EXPECTED_LINK_NAMES + ] + + # Raises error when two urdfs have the same link name + with pytest.raises(ValueError): + urdf.assert_link_names_exist([urdf_text, urdf_text], EXPECTED_LINK_NAMES) + + # This should fail because the link names don't actually exist + with pytest.raises(ValueError): + urdf.assert_link_names_exist( + [urdf_text_namespaced, urdf_text_namespaced], EXPECTED_LINK_NAMES + ) + + # Test happy path with just one urdf file + urdf.assert_link_names_exist([urdf_text_namespaced], _expected_links_namespaced) + urdf.assert_link_names_exist([urdf_text], EXPECTED_LINK_NAMES) + + # Test happy paths where one of the two urdfs is namespaced, and the links all exist + urdf.assert_link_names_exist( + [urdf_text, urdf_text_namespaced], + [*_expected_links_namespaced, *EXPECTED_LINK_NAMES], + ) + urdf.assert_link_names_exist([urdf_text, urdf_text_namespaced], EXPECTED_LINK_NAMES) + urdf.assert_link_names_exist( + [urdf_text, urdf_text_namespaced], _expected_links_namespaced + ) + + +def test_assert_link_names_exist() -> None: + urdf_text = GENERIC_URDF.read_text() + + # Raises error when no names provided + with pytest.raises(ValueError): + urdf.assert_link_names_exist([urdf_text], []) + + # Raises an error when invalid names provided + with pytest.raises(ValueError): + urdf.assert_link_names_exist( + [urdf_text], [*EXPECTED_LINK_NAMES, "nonexistent-link"] + ) + + # No assertions raised when all link names are valid + urdf.assert_link_names_exist([urdf_text], EXPECTED_LINK_NAMES) + + +def test_prepend_namespace() -> None: + urdf_text: str = GENERIC_URDF.read_text() + namespace = "cool_namespace" + modified = urdf.prepend_namespace(urdf_text, namespace=namespace) + + for changes in (EXPECTED_JOINT_NAMES, EXPECTED_LINK_NAMES): + expected_changes = list(map(urdf_text.count, changes)) + actual_changes = list( + map( + modified.count, + [ + urdf.NAMESPACE_FMT.format(namespace=namespace, name=j) + for j in changes + ], + ) + ) + assert expected_changes == actual_changes diff --git a/pkgs/node_helpers/node_helpers_test/unit/markers/test_interactive_marker.py b/pkgs/node_helpers/node_helpers_test/unit/markers/test_interactive_marker.py new file mode 100644 index 0000000..b3538ba --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/markers/test_interactive_marker.py @@ -0,0 +1,29 @@ +from node_helpers.markers import InteractiveVolumeMarker + + +def test_init_configuration() -> None: + """Just basic testing that the object is built correctly""" + name = "cool_name" + frame_id = "cool_frame" + scale = [1.0, 2.0, 3.0] + description = "cool description" + + marker = InteractiveVolumeMarker( + name=name, frame_id=frame_id, scale=scale, description=description, fixed=False + ) + assert marker.interactive_marker.name == name + assert marker.interactive_marker.description == description + + # The frame_id should be passed to both the interactive_marker and its box + assert marker.interactive_marker.header.frame_id == frame_id + assert marker.box.header.frame_id == frame_id + + # The scale parameter should refer to the underlying box + box_scale = [marker.box.scale.x, marker.box.scale.y, marker.box.scale.z] + assert box_scale == scale + + marker_scale = marker.interactive_marker.scale + assert marker_scale == 3 * 
2.25 + assert ( + len(marker.interactive_marker.controls) == 7 + ), "There should be 7 controllers, 1 for the box, and 6 since it's a 6 dof marker!" diff --git a/pkgs/node_helpers/node_helpers_test/unit/parameters/__init__.py b/pkgs/node_helpers/node_helpers_test/unit/parameters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/node_helpers/node_helpers_test/unit/parameters/test_choosable_object.py b/pkgs/node_helpers/node_helpers_test/unit/parameters/test_choosable_object.py new file mode 100644 index 0000000..d33b0e8 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/parameters/test_choosable_object.py @@ -0,0 +1,362 @@ +from abc import ABC, abstractmethod +from collections.abc import Generator + +import pytest +from node_helpers.parameters import Choosable, choosable_object +from node_helpers.parameters.choosable_object import ( + ClassRegistryType, + DuplicateRegistrationError, + InstanceRegistryType, + UnregisteredChoosableError, +) +from pydantic import BaseModel + +from .test_parameters_mixin_declaring_pydantic_models import ParameterNode + + +@pytest.fixture() +def clean_registries() -> ( + Generator[tuple[ClassRegistryType, InstanceRegistryType], None, None] +): + """Return an unmodified global class and instance registry before each test""" + # Copy the contents of the registries before the test + class_registry = choosable_object._global_choosable_class_registry + instance_registry = choosable_object._global_choosable_instance_registry + classes_before = class_registry.copy() + instances_before = instance_registry.copy() + + # Clear the registries before and after each test + class_registry.clear() + instance_registry.clear() + yield class_registry, instance_registry + class_registry.clear() + instance_registry.clear() + + # Re add items to the class and instance registries + class_registry.update(classes_before) + instance_registry.update(instances_before) + + +def test_pydantic_parsing_integration_for_classes( + clean_registries: tuple[ClassRegistryType, InstanceRegistryType], +) -> None: + """Test that classes can be chosen via configuration, using parameter: type[Class] + in a pydantic model. + """ + + class BaseSomeClassThatIsChoosable(Choosable, ABC): + """Test that the base class pattern still works with choosable classes""" + + @abstractmethod + def some_method(self) -> str: + pass + + assert len(choosable_object._global_choosable_class_registry) == 1 + + class SomeChoosableClassImplementation(BaseSomeClassThatIsChoosable): + def some_method(self) -> str: + return "hey this shouldn't be called" + + assert len(choosable_object._global_choosable_class_registry) == 1 + + class AnotherChoosableClassImplementation(BaseSomeClassThatIsChoosable): + def some_method(self) -> str: + return "correct!" 
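+ # No explicit registration call is needed: Choosable subclasses self-register as they are defined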
+ + # Validate the registry now has these classes + assert choosable_object._global_choosable_class_registry == { + BaseSomeClassThatIsChoosable: { + "BaseSomeClassThatIsChoosable": BaseSomeClassThatIsChoosable, + "SomeChoosableClassImplementation": SomeChoosableClassImplementation, + "AnotherChoosableClassImplementation": AnotherChoosableClassImplementation, + } + } + + class ParametersWithChoosableClass(BaseModel): + choosable_class_name: type[BaseSomeClassThatIsChoosable] + + node = ParameterNode() + node.config_values["a.choosable_class_name"] = "AnotherChoosableClassImplementation" + + model = node.declare_from_pydantic_model( + ParametersWithChoosableClass, "a", subscribe_to_updates=False + ) + + # Validate that the declared values use strings, but the class was correctly chosen + assert ( + node.declared["a.choosable_class_name"][0] + == "AnotherChoosableClassImplementation" + ) + assert model.choosable_class_name is AnotherChoosableClassImplementation + assert model.choosable_class_name().some_method() == "correct!" + + +def test_pydantic_parsing_integration_for_instances( + clean_registries: tuple[ClassRegistryType, InstanceRegistryType], +) -> None: + """Test that (registered) instances of a class can be chosen via configuration, + using parameter: Class in a pydantic model. + """ + + class BaseChoosableInstance(Choosable): + pass + + class InheritorChoosableInstance(BaseChoosableInstance): + pass + + class ParametersWithChoosableInstance(BaseModel): + choosable_instance: BaseChoosableInstance # note it's not wrapped in a type[] + + # Register a few instances + instance_1 = BaseChoosableInstance() + instance_2 = InheritorChoosableInstance() + instance_3 = BaseChoosableInstance() + instance_1.register_instance("instance_1") + instance_2.register_instance("instance_2") + instance_3.register_instance("instance_3") + + node = ParameterNode() + node.config_values["a.choosable_instance"] = "instance_2" + + model = node.declare_from_pydantic_model( + ParametersWithChoosableInstance, "a", subscribe_to_updates=False + ) + + # Validate that the chosen instance was retrieved from the global registry based on + # the user's configuration values + assert model.choosable_instance is instance_2 + assert node.declared["a.choosable_instance"][0] == "instance_2" + + # Now try declaring an instance that doesn't exist + node.config_values["a.choosable_instance"] = "instance_4" + with pytest.raises(UnregisteredChoosableError): + model = node.declare_from_pydantic_model(ParametersWithChoosableInstance, "a") + + +def test_descriptive_error_is_returned_when_class_not_registered( + clean_registries: tuple[ClassRegistryType, InstanceRegistryType], +) -> None: + class BaseChoosable(Choosable): + pass + + class RegisteredClassA(BaseChoosable): + pass + + class RegisteredClassB(BaseChoosable): + pass + + assert ( + BaseChoosable.get_registered_child_class("RegisteredClassA") is RegisteredClassA + ) + assert ( + BaseChoosable.get_registered_child_class("RegisteredClassB") is RegisteredClassB + ) + + # This class isn't defined anywhere, so it should raise an error + with pytest.raises(UnregisteredChoosableError): + BaseChoosable.get_registered_child_class("RegisteredClassC") + + # This class is defined but only under the child class, so it should raise an error + with pytest.raises(RuntimeError): + Choosable.get_registered_child_class("RegisteredClassA") + with pytest.raises(RuntimeError): + Choosable.get_registered_child_class("RegisteredClassB") + + +def test_registration_happens_with_basest_choosable_class( 
clean_registries: tuple[ClassRegistryType, InstanceRegistryType], +) -> None: + """Test that if there's a chain of subclassing, the 'root' inheritor of + Choosable is the one that gets registered + """ + + class BaseChoosable(Choosable): + pass + + class RegisteredClassA(BaseChoosable): + pass + + class RegisteredClassB(RegisteredClassA): + pass + + assert ( + BaseChoosable.get_registered_child_class("RegisteredClassA") is RegisteredClassA + ) + + # Try accessing RegisteredClassB from the base class, the A class, and itself (all + # work) + assert ( + BaseChoosable.get_registered_child_class("RegisteredClassB") is RegisteredClassB + ) + assert ( + RegisteredClassB.get_registered_child_class("RegisteredClassB") + is RegisteredClassB + ) + assert ( + RegisteredClassA.get_registered_child_class("RegisteredClassB") + is RegisteredClassB + ) + + +def test_custom_registered_name( + clean_registries: tuple[ClassRegistryType, InstanceRegistryType], +) -> None: + """Test that a class can be registered under a custom name via the registered_name + keyword, while other classes keep their default names + """ + + class BaseChoosable(Choosable): + pass + + class CustomNameClass(BaseChoosable, registered_name="custom_name"): + pass + + class DefaultNameClass(BaseChoosable): + pass + + assert BaseChoosable.get_registered_child_class("custom_name") is CustomNameClass + assert ( + BaseChoosable.get_registered_child_class("DefaultNameClass") is DefaultNameClass + ) + + +def test_instance_registration_is_scoped( + clean_registries: tuple[ClassRegistryType, InstanceRegistryType], +) -> None: + """Test that instances can be registered and retrieved, using scoping rules""" + + class BaseA(Choosable): + pass + + class ImplA(BaseA): + pass + + class BaseB(Choosable): + pass + + class ImplB(BaseB): + pass + + # Shotgun approach to registering instances. 
The real check here is to ensure that + # the instances are registered in the correct scope, and that both base and child + # classes can reach the 'scoped registry' for their base class + instance_a_1 = ImplA() + instance_a_2 = ImplA() # noqa: F841 + instance_a_3 = BaseA() + instance_b_1 = ImplB() + instance_b_2 = ImplB() # noqa: F841 + instance_b_3 = BaseB() + + class_registry, instance_registry = clean_registries + assert len(class_registry) == 2 + assert len(instance_registry) == 0 + + # Add two instances of A, and validate that they live under the same scope + instance_a_1.register_instance("test_a_1") + assert len(instance_registry) == 1 + instance_a_3.register_instance("test_a_3") + assert len(instance_registry) == 1 + assert len(class_registry) == 2 + assert BaseA.get_registered_instance("test_a_1") is instance_a_1 + assert BaseA.get_registered_instance("test_a_3") is instance_a_3 + assert instance_a_3.get_registered_instance("test_a_1") is instance_a_1 + assert instance_a_3.get_registered_instance("test_a_3") is instance_a_3 + + # Now validate that if a B instance is created, it lives under a new scope + instance_b_1.register_instance("test_b_1") + instance_b_3.register_instance("test_b_3") + assert len(instance_registry) == 2 + assert BaseB.get_registered_instance("test_b_1") is instance_b_1 + assert ImplB.get_registered_instance("test_b_3") is instance_b_3 + + # Now validate that failures occur if you try to access the wrong scope + with pytest.raises(UnregisteredChoosableError): + BaseA.get_registered_instance("test_b_1") + + with pytest.raises(UnregisteredChoosableError): + ImplB.get_registered_instance("test_a_1") + + +def test_reusing_the_same_name_fails_for_classes( + clean_registries: tuple[ClassRegistryType, InstanceRegistryType], +) -> None: + """Test you aren't allowed to register two classes with the same name""" + + class BaseA(Choosable): + pass + + # This is okay, because two root 'scopes' with the same name can exist + class BaseA(Choosable): # type: ignore # noqa: F811 + pass + + class ImplA(BaseA): + pass + + # This however is not okay, because two classes with the same name and shared scope + # should not exist + with pytest.raises(DuplicateRegistrationError): + + class ImplA(BaseA): # type: ignore # noqa: F811 + pass + + # This is fine however, because a custom name is used + class ImplA(BaseA, registered_name="custom_name"): # type: ignore # noqa: F811 + pass + + +def test_reusing_same_name_fails_for_instances( + clean_registries: tuple[ClassRegistryType, InstanceRegistryType], +) -> None: + """Test you aren't allowed to register two instances with the same name""" + + class Base(Choosable): + pass + + class SomeInstance(Base): + pass + + instance_1 = SomeInstance() + instance_2 = SomeInstance() + instance_3 = Base() + + instance_1.register_instance("instance_1") + + with pytest.raises(DuplicateRegistrationError): + instance_1.register_instance("instance_1") + with pytest.raises(DuplicateRegistrationError): + instance_2.register_instance("instance_1") + with pytest.raises(DuplicateRegistrationError): + instance_3.register_instance("instance_1") + + instance_2.register_instance("instance_2") + instance_3.register_instance("instance_3") + + +def test_getting_registered_name_for_class() -> None: + class Base(Choosable, registered_name="custom_base_name"): + pass + + class Impl(Base): + pass + + assert Base.get_registered_class_name() == "custom_base_name" + assert Impl.get_registered_class_name() == "Impl" + + +def test_getting_registered_name_for_instance() -> 
None: + class Base(Choosable): + pass + + instance = Base() + instance.register_instance("custom_instance_name") + + assert instance.get_registered_instance_name() == "custom_instance_name" + assert Base.get_registered_class_name() == "Base" + + +def test_getting_registered_name_for_unregistered_instance() -> None: + class Base(Choosable): + pass + + unregistered_instance = Base() + with pytest.raises(UnregisteredChoosableError): + unregistered_instance.get_registered_instance_name() diff --git a/pkgs/node_helpers/node_helpers_test/unit/parameters/test_loading.py b/pkgs/node_helpers/node_helpers_test/unit/parameters/test_loading.py new file mode 100644 index 0000000..c8990f4 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/parameters/test_loading.py @@ -0,0 +1,238 @@ +import contextlib +from collections.abc import Generator +from copy import deepcopy +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Any, cast + +import pytest +import yaml +from node_helpers.parameters import ( + FIELD_PLACEHOLDER, + Namespace, + ParameterLoader, + loading, +) +from pydantic import BaseModel +from rclpy.parameter import Parameter + + +def test_namespace_with_item_copies() -> None: + """Tests that Namespace.with_item creates an independent copy of its items""" + namespace1 = Namespace(["manipulators", "bomb_defuser"]) + namespace2 = namespace1.with_item("red_wire") + + namespace1.items.append("blue_wire") + + assert namespace1.items == ["manipulators", "bomb_defuser", "blue_wire"] + assert namespace2.items == ["manipulators", "bomb_defuser", "red_wire"] + + +def test_parameter_loader_base_loading() -> None: + """Tests loading of a single parameter file""" + with _parameter_directory(_BASE_PARAMETERS_YAML) as parameter_dir: + loader = ParameterLoader( + parameters_directory=parameter_dir, + meta_parameters_schema=MetaParameters, + ) + + assert loader.parameters == _EXPECTED_PARAMETERS + assert loader.meta_parameters == _EXPECTED_META_PARAMETERS + + +def test_parameter_loader_override_loading() -> None: + """Tests that parameters and meta parameters are combined as expected when loading + of base and override parameter file + """ + with ( + _parameter_directory(_BASE_PARAMETERS_YAML) as parameter_dir, + _parameter_file(_OVERRIDE_PARAMETERS_YAML) as override_path, + ): + loader = ParameterLoader( + parameters_directory=parameter_dir, + override_file=override_path, + meta_parameters_schema=MetaParameters, + ) + + # Validate ros parameters + expected_parameters = cast( + dict[Namespace, dict[Namespace, Any]], deepcopy(_EXPECTED_PARAMETERS) + ) + expected_parameters[Namespace(["nearby_theater"])][ + Namespace(["incident_report", "description"]) + ] = "How wrong I was! Friendship is a beautiful thing." + expected_parameters[Namespace(["nearby_theater"])][ + Namespace(["lessons_learned"]) + ] = True + + assert loader.parameters == expected_parameters + + # Validate meta parameters + expected_meta_parameters = deepcopy(_EXPECTED_META_PARAMETERS.model_dump()) + expected_meta_parameters["hangouts"] = ["detention"] + assert loader.meta_parameters.model_dump() == expected_meta_parameters + + +def test_parameter_loader_file_saving() -> None: + """Tests that saving the resulting canonical parameter file has the expected data. 
+ It also shouldn't include any meta parameters data.""" + with _parameter_directory(_BASE_PARAMETERS_YAML) as parameter_dir: + loader: ParameterLoader[BaseModel] = ParameterLoader( + parameters_directory=parameter_dir + ) + + loaded_text = yaml.full_load(loader.ros_parameters_file.read_text()) + base_parameters = yaml.full_load(_BASE_PARAMETERS_YAML) + del base_parameters[loading._META_PARAMETERS_KEY] + assert loaded_text == base_parameters + + +def test_parameter_loader_for_node() -> None: + """Tests getting parameters for a single node from the loader""" + with _parameter_directory(_BASE_PARAMETERS_YAML) as parameter_dir: + loader: ParameterLoader[BaseModel] = ParameterLoader( + parameters_directory=parameter_dir + ) + + params = loader.parameters_for_node(Namespace(["high_school_movie", "cool_kid"])) + assert len(params) == 1 + assert params[0].name == "is_cool" + assert params[0].type_ == Parameter.Type.BOOL + assert params[0].value + + +def test_parameter_loader_loads_without_meta_parameters_schema_provided() -> None: + """Test that not providing a meta parameters schema is okay, but you can't access + the meta_parameters property without an error being raised.""" + with _parameter_directory(_BASE_PARAMETERS_YAML) as parameter_dir: + loader: ParameterLoader[BaseModel] = ParameterLoader( + parameters_directory=parameter_dir + ) + + # The class shouldn't let you load meta parameters when a schema wasn't initially + # provided. + with pytest.raises(RuntimeError): + loader.meta_parameters # noqa: B018 + assert loader._meta_parameters is None + + +def test_parameter_loader_invalid_meta_parameters() -> None: + """Test that an exception is raised when a meta parameters schema is given but + the meta parameters key isn't in the yaml.""" + with _parameter_directory(_NO_META_PARAMETERS_YAML) as parameter_dir: + # This should work + ParameterLoader(parameters_directory=parameter_dir) + + # This should fail + with pytest.raises(loading.ParameterLoadingError): + ParameterLoader( + parameters_directory=parameter_dir, + meta_parameters_schema=MetaParameters, + ) + + +def test_merge_dictionaries_deepcopies() -> None: + """Make sure that merge_dictionary isn't editing in place.""" + val_1 = object() + val_2 = object() + val_2_replacement = object() + val_3 = object() + + a = {"key_1": val_1, "key_2": val_2} + b = {"key_2": val_2_replacement, "key_3": val_3} + + merged = ParameterLoader._merge_dictionaries(a, b) + + assert ["key_1", "key_2", "key_3"] == list(merged.keys()) + for val in merged.values(): + assert val not in (val_1, val_2, val_3, val_2_replacement) + + +@contextlib.contextmanager +def _parameter_file(value: str) -> Generator[Path, None, None]: + with TemporaryDirectory() as temp_dir: + temp_file = Path(temp_dir) / "some-params.yaml" + with temp_file.open("w") as f: + f.write(value) + f.flush() + + yield Path(temp_file) + + +@contextlib.contextmanager +def _parameter_directory(value: str) -> Generator[Path, None, None]: + with _parameter_file(value) as f: + yield f.parent + + +class MetaParameters(BaseModel): + school_name: str + cool_kid_to_dork_ratio: float + hangouts: list[str] + + +_BASE_PARAMETERS_YAML = """ +meta_parameters: + school_name: Rydell High School + cool_kid_to_dork_ratio: 0.9 + hangouts: ["corner store", "under the bleachers"] + +high_school_movie: + cool_kid: + ros__parameters: + is_cool: true + dork_kid: + ros__parameters: + catchphrase: This math stuff is easy + +nearby_theater: + ros__parameters: + employees: + - cool_kid + - dork_kid + incident_report: + type: 
Endearing Friendship + description: Jocks can't hang with nerds! +""" + +_EXPECTED_META_PARAMETERS = MetaParameters( + school_name="Rydell High School", + cool_kid_to_dork_ratio=0.9, + hangouts=["corner store", "under the bleachers"], +) +_EXPECTED_PARAMETERS = { + Namespace(["high_school_movie", "cool_kid"]): { + Namespace(["is_cool"]): True, + }, + Namespace(["high_school_movie", "dork_kid"]): { + Namespace(["catchphrase"]): "This math stuff is easy", + }, + Namespace(["nearby_theater"]): { + Namespace(["employees"]): ["cool_kid", "dork_kid"], + Namespace(["incident_report", "type"]): "Endearing Friendship", + Namespace(["incident_report", "description"]): "Jocks can't hang with nerds!", + }, +} + +_OVERRIDE_PARAMETERS_YAML = """ +meta_parameters: + hangouts: ["detention"] + +nearby_theater: + ros__parameters: + incident_report: + description: How wrong I was! Friendship is a beautiful thing. + lessons_learned: true +""" + +_INCOMPLETE_PARAMETERS_YAML = f""" +some_node: + ros__parameters: + some_param: {FIELD_PLACEHOLDER} +""" + +_NO_META_PARAMETERS_YAML = """ +some_node: + ros__parameters: + some_param: 3 +""" diff --git a/pkgs/node_helpers/node_helpers_test/unit/parameters/test_parameters_mixin_declaring_pydantic_models.py b/pkgs/node_helpers/node_helpers_test/unit/parameters/test_parameters_mixin_declaring_pydantic_models.py new file mode 100644 index 0000000..deaa626 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/parameters/test_parameters_mixin_declaring_pydantic_models.py @@ -0,0 +1,238 @@ +from enum import Enum +from pathlib import Path +from typing import Any + +import pytest +from node_helpers.parameters import ParameterMixin, RequiredParameterNotSetException +from pydantic import BaseModel, Field +from rcl_interfaces.msg import ParameterDescriptor +from rclpy import Parameter + + +class ParameterNode(ParameterMixin): + def __init__(self) -> None: + self.declared: dict[str, tuple[Any, ParameterDescriptor]] = {} + """Keeps track of declared parameters, for use in assertions.""" + + self.config_values: dict[str, Any] = {} + """Set this to make declare_parameter return a particular value for a parameter + name. This emulates how ROS might return the value of a parameter as defined + by a parameter file.""" + + def declare_parameter( + self, + name: str, + value: Any, + descriptor: ParameterDescriptor, + ignore_override: bool, + ) -> Parameter: + """Mock the Node.declare_parameter function""" + assert ( + ignore_override is False + ), "The ParameterMixin is expected not to ignore parameter overrides" + assert isinstance(descriptor, ParameterDescriptor) + + # Keep track of declared parameters, to help testing + self.declared[name] = (self.config_values.get(name, value), descriptor) + + # Return the value as ROS would after having pulled it from configuration + return Parameter(name=name, value=self.declared[name][0]) + + def get_namespace(self) -> str: + """Used for descriptive logs""" + return "i'm not really used in these tests yay!" 
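+ + +# The ParameterNode mock above fakes just enough of the Node API for these tests to run without a ROS context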
+
+
+def test_default_values() -> None:
+    """Test that default values work"""
+
+    class Model(BaseModel):
+        attr_with_default: str = "ayy"
+        attr_without_default: int
+
+    node = ParameterNode()
+    with pytest.raises(RequiredParameterNotSetException):
+        model = node.declare_from_pydantic_model(
+            Model, "cool_prefix", subscribe_to_updates=False
+        )
+
+    # Then after setting the value in the "configuration file"
+    node.config_values["cool_prefix.attr_without_default"] = 3
+    model = node.declare_from_pydantic_model(
+        Model, "cool_prefix", subscribe_to_updates=False
+    )
+    assert model.attr_without_default == 3
+
+
+def test_field_descriptions() -> None:
+    """Test that field descriptions are correctly passed into the parameter
+    description"""
+    with_description_description_value = "This is a cool value"
+
+    class Model(BaseModel):
+        no_description: str = "ayy"
+        with_description: int = Field(
+            description=with_description_description_value, default=3
+        )
+
+    node = ParameterNode()
+    node.declare_from_pydantic_model(Model, "prefix", subscribe_to_updates=False)
+
+    assert (
+        node.declared["prefix.with_description"][1].description
+        == with_description_description_value
+    )
+    assert node.declared["prefix.no_description"][1].description == ""
+
+
+def test_prefix_is_added() -> None:
+    """Test that an empty prefix is not allowed for pydantic models, and that a
+    non-empty prefix is prepended to parameter names"""
+
+    expected_value = [1, 2, 3, 4]
+
+    class Model(BaseModel):
+        attr_a: list[int]
+
+    node = ParameterNode()
+    node.config_values["cool_prefix.attr_a"] = expected_value
+
+    # Try without prefix
+    with pytest.raises(ValueError):
+        node.declare_from_pydantic_model(Model, "", subscribe_to_updates=False)
+
+    # Try with prefix
+    model = node.declare_from_pydantic_model(
+        Model, "cool_prefix", subscribe_to_updates=False
+    )
+    assert node.declared["cool_prefix.attr_a"][0] == expected_value
+    assert model.attr_a == expected_value
+
+
+def test_complex_types() -> None:
+    """Test that all ROS types are supported by this system.
+    This validates that the parameter mixin isn't just using Pydantic, but also uses
+    typing.get_type_hints() behind the scenes.
+
+    Without get_type_hints usage, this test used to fail because pydantic had a more
+    complicated way of handling types like List[Type].
+    """
+
+    # These are all supported types, as seen in the CONFIG_TO_ROS_MAPPING
+    class ManyTypes(BaseModel):
+        attr_0: bool = True
+        attr_1: int = 3
+        attr_2: float = 4.0
+        attr_3: str = ""
+        attr_4: list[bytes] = [b"1", b"2", b"3"]
+        attr_5: list[bool] = [True, False]
+        attr_6: list[int] = [1, 2, 3]
+        attr_7: list[float] = [1.0, 2.0, 3.0]
+        attr_8: list[str] = ["a", "b", "cde"]
+
+    node = ParameterNode()
+    assert isinstance(ManyTypes().attr_4, list | tuple)
+    assert all(isinstance(v, bytes) for v in ManyTypes().attr_4)
+
+    # The real test is that this call doesn't fail
+    node.declare_from_pydantic_model(ManyTypes, "prefix", subscribe_to_updates=False)
+
+    # Check that the defaults were pulled through without being changed in any way.
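+    # (node.declared maps parameter name -> (value, descriptor) in this mock, so the
+    # [0] index below reads back the value that was declared.)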
+    assert node.declared["prefix.attr_0"][0] == ManyTypes().attr_0
+    assert node.declared["prefix.attr_1"][0] == ManyTypes().attr_1
+    assert node.declared["prefix.attr_2"][0] == ManyTypes().attr_2
+    assert node.declared["prefix.attr_3"][0] == ManyTypes().attr_3
+    assert node.declared["prefix.attr_4"][0] == ManyTypes().attr_4
+    assert node.declared["prefix.attr_5"][0] == ManyTypes().attr_5
+    assert node.declared["prefix.attr_6"][0] == ManyTypes().attr_6
+    assert node.declared["prefix.attr_7"][0] == ManyTypes().attr_7
+    assert node.declared["prefix.attr_8"][0] == ManyTypes().attr_8
+
+
+def test_nonros_arbitrary_types_are_wrapped() -> None:
+    """Test that arbitrary (e.g. non-ROS) types are declared as their primitive
+    representation (e.g. strings), then wrapped back into the arbitrary type"""
+
+    class ExampleEnum(Enum):
+        OPTION_1 = "option_1"
+        OPTION_2 = "option_2"
+
+    class ArbitraryTypeModel(BaseModel):
+        some_path: Path
+        some_enum: ExampleEnum
+
+        # Validates that we can handle Nonetype when it's set as the default value
+        some_nonetype: None = None
+
+    node = ParameterNode()
+    node.config_values["prefix.some_path"] = "/path/to/something"
+    node.config_values["prefix.some_enum"] = "option_2"
+
+    model = node.declare_from_pydantic_model(
+        ArbitraryTypeModel, "prefix", subscribe_to_updates=False
+    )
+
+    # Validate that the declared values use strings, but the model values are wrapped
+    assert node.declared["prefix.some_enum"][0] == "option_2"
+    assert node.declared["prefix.some_path"][0] == "/path/to/something"
+    assert node.declared["prefix.some_nonetype"][0] is None
+    assert model.some_enum is ExampleEnum.OPTION_2
+    assert model.some_path == Path("/path/to/something")
+    assert model.some_nonetype is None
+
+
+def test_nested_models() -> None:
+    """Tests that we can nest models in other models, and that their parameters are
+    prefixed properly
+    """
+
+    class ParentModel(BaseModel):
+        class ChildModel(BaseModel):
+            child_attr: str
+
+        parent_attr: str
+        child: ChildModel
+
+    parent_value = "I am a parent"
+    child_value = "I'm just a wittle baybee"  # Sorry
+
+    node = ParameterNode()
+    node.config_values["prefix.parent_attr"] = parent_value
+    node.config_values["prefix.child.child_attr"] = child_value
+
+    node.declare_from_pydantic_model(ParentModel, "prefix", subscribe_to_updates=False)
+    assert len(node.declared) == 2
+    assert node.declared["prefix.parent_attr"][0] == parent_value
+    assert node.declared["prefix.child.child_attr"][0] == child_value
+
+
+def test_union_types_basic_use() -> None:
+    class ExampleEnum(Enum):
+        OPTION_1 = "option_1"
+        OPTION_2 = "option_2"
+
+    class ModelWithUnions(BaseModel):
+        bool_or_int: bool | int
+        int_or_str: int | str
+        might_be_anything: int | str | list[int] | list[str] | bool
+        path_or_int: Path | int
+        int_or_enum: int | ExampleEnum
+        int_or_enum_with_default: int | ExampleEnum = ExampleEnum.OPTION_1
+        list_int_or_none: list[int] | None = None
+
+    node = ParameterNode()
+    node.config_values["prefix.might_be_anything"] = [1, 2, 3, 4]
+    node.config_values["prefix.bool_or_int"] = True
+    node.config_values["prefix.int_or_str"] = "ayy"
+    node.config_values["prefix.path_or_int"] = "/path/to/something"
+    node.config_values["prefix.int_or_enum"] = "option_2"
+
+    model = node.declare_from_pydantic_model(
+        ModelWithUnions, "prefix", subscribe_to_updates=False
+    )
+
+    # Validate that the model is correctly declared
+    assert model.might_be_anything == [1, 2, 3, 4]
+    assert model.bool_or_int is True
+    assert model.int_or_str == "ayy"
+    assert model.path_or_int == Path("/path/to/something")
+    assert model.int_or_enum is ExampleEnum.OPTION_2
+    assert model.int_or_enum_with_default is ExampleEnum.OPTION_1
+    assert model.list_int_or_none is None
diff --git a/pkgs/node_helpers/node_helpers_test/unit/parameters/test_parameters_mixin_declaring_values.py b/pkgs/node_helpers/node_helpers_test/unit/parameters/test_parameters_mixin_declaring_values.py
new file mode 100644
index 0000000..1ef63d6
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers_test/unit/parameters/test_parameters_mixin_declaring_values.py
@@ -0,0 +1,92 @@
+from unittest import mock
+
+import pytest
+from node_helpers.parameters import (
+    FIELD_PLACEHOLDER,
+    ParameterMixin,
+    RequiredParameterNotSetException,
+    UnfilledParametersFileError,
+)
+
+
+def test_required_disallows_default_value() -> None:
+    """Test that setting a default value is not allowed when required=True"""
+    with pytest.raises(ValueError):
+        ParameterMixin().declare_and_get_parameter(
+            name="test",
+            required=True,
+            type_=str,
+            default_value="Uh oh, this shouldn't be allowed!",
+        )
+
+
+def test_default_value_doesnt_match_assigned_type() -> None:
+    """Test that if the default_value doesn't match the documented type_,
+    a TypeError is raised."""
+    with pytest.raises(TypeError):
+        ParameterMixin().declare_and_get_parameter(
+            name="coolparam", default_value=3.0, type_=str
+        )
+
+
+def test_required_raises_error_if_not_set() -> None:
+    """Test that if required=True and the value is not set, an error is
+    raised.
+    """
+
+    class CoolNode(mock.MagicMock, ParameterMixin):
+        pass
+
+    # Mock a node whose parameter value was never set in configuration
+    node = CoolNode()
+    parameter_mock = mock.MagicMock()
+    parameter_mock.value = None
+    node.declare_parameter.return_value = parameter_mock
+
+    with pytest.raises(RequiredParameterNotSetException):
+        node.declare_and_get_parameter(
+            name="test",
+            required=True,
+            type_=float,
+        )
+
+
+def test_incorrect_type() -> None:
+    """Test that if a user described the parameter as being a certain type, but then
+    configures it with another type, an error is raised."""
+
+    class CoolNode(mock.MagicMock, ParameterMixin):
+        pass
+
+    # Mock a node that will be pre-filled with an incorrect type
+    node = CoolNode()
+    parameter_mock = mock.MagicMock()
+    parameter_mock.value = "This is a string type!!!"
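+    # (declare_parameter on CoolNode is itself a MagicMock, so wiring its
+    # return_value below is what simulates a parameter file providing this value.)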
+    node.declare_parameter.return_value = parameter_mock
+
+    with pytest.raises(TypeError):
+        node.declare_and_get_parameter(
+            name="test",
+            required=True,
+            type_=float,
+        )
+
+
+def test_unfilled_detection() -> None:
+    """Test that fields with a sentinel placeholder value cause an error"""
+
+    class CoolNode(mock.MagicMock, ParameterMixin):
+        pass
+
+    # Mock a node that will be pre-filled with a placeholder value
+    node = CoolNode()
+    parameter_mock = mock.MagicMock()
+    parameter_mock.value = FIELD_PLACEHOLDER
+    node.declare_parameter.return_value = parameter_mock
+
+    with pytest.raises(UnfilledParametersFileError):
+        node.declare_and_get_parameter(
+            name="test",
+            required=True,
+            type_=str,
+        )
diff --git a/pkgs/node_helpers/node_helpers_test/unit/parameters/test_path.py b/pkgs/node_helpers/node_helpers_test/unit/parameters/test_path.py
new file mode 100644
index 0000000..1657f9f
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers_test/unit/parameters/test_path.py
@@ -0,0 +1,6 @@
+from node_helpers.parameters import param_path
+
+
+def test_basic_operation() -> None:
+    result = param_path("pkgs/node_helpers/node_helpers/__init__.py")
+    assert result.is_file()
diff --git a/pkgs/node_helpers/node_helpers_test/unit/pubsub/test_topic.py b/pkgs/node_helpers/node_helpers_test/unit/pubsub/test_topic.py
new file mode 100644
index 0000000..b17bbf4
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers_test/unit/pubsub/test_topic.py
@@ -0,0 +1,164 @@
+import gc
+from queue import Queue
+
+import pytest
+from node_helpers.pubsub import DuplicateSubscriberError, SubscriberNotFoundError, Topic
+
+
+def test_publishing_subscribing() -> None:
+    """A basic test of Publisher/Subscriber functionality"""
+    topic = Topic[str]()
+    queue_1: Queue[str] = Queue()
+
+    # Smoke test
+    topic.publish("no one should get this!")
+    assert queue_1.qsize() == 0
+    topic.subscribe(queue_1.put)
+    topic.publish("topic_1_content")
+    assert queue_1.qsize() == 1
+    assert queue_1.get() == "topic_1_content"
+
+
+def test_gc_listener_gets_removed() -> None:
+    """A test to make sure that if a listener is garbage collected, it is removed
+    from the topic's subscribers the next time publish is called"""
+    topic = Topic[str]()
+    queue_1: Queue[str] = Queue()
+
+    # Subscribe the listener, queue_1.put, and verify there is a weak ref
+    assert len(topic._subscribers) == 0
+    topic.subscribe(queue_1.put)
+    assert len(topic._subscribers) == 1
+
+    # Publish some content
+    topic.publish("alex is cool")
+    assert queue_1.get() == "alex is cool"
+
+    # Delete the listener and garbage collect
+    del queue_1
+    gc.collect()
+
+    # Test the listener is deleted the next time something is published
+    assert len(topic._subscribers) == 1
+    topic.publish("i am a string look at me")
+    assert len(topic._subscribers) == 0
+
+
+def test_removing_listener() -> None:
+    """Test removing subscribers from a topic works"""
+    topic = Topic[str]()
+    queue_1: Queue[str] = Queue()
+    queue_2: Queue[str] = Queue()
+
+    # Subscribe two listeners and publish a message to both
+    topic.subscribe(queue_1.put)
+    topic.subscribe(queue_2.put)
+    topic.publish("to_both")
+    assert queue_1.get() == "to_both"
+    assert queue_2.get() == "to_both"
+
+    # Remove a listener, and make sure it gets nothing
+    topic.unsubscribe(queue_1.put)
+    topic.publish("to_queue_2")
+    assert queue_2.get() == "to_queue_2"
+    assert queue_1.qsize() == 0
+
+    # Test removing an already removed listener raises an error
+    with pytest.raises(SubscriberNotFoundError):
+        topic.unsubscribe(queue_1.put)
+
+
+def test_no_duplicate_subscribers() -> None:
+    """Test that subscribing the same callable twice raises an error"""
+    topic = Topic[None]()
+    queue: Queue[None] = Queue()
+
+    topic.subscribe(queue.put)
+
+    # This should not be okay
+    with pytest.raises(DuplicateSubscriberError):
+        topic.subscribe(queue.put)
+
+    # A second (new) subscriber should be okay
+    second_queue: Queue[None] = Queue()
+    topic.subscribe(second_queue.put)
+
+
+def test_subscribe_call_counts() -> None:
+    """Tests that every subscriber is called exactly once per publish"""
+    topic = Topic[int]()
+
+    subscriber1_call_count = 0
+
+    def subscriber1(value: int) -> None:
+        assert value == 5
+        nonlocal subscriber1_call_count
+        subscriber1_call_count += 1
+
+    topic.subscribe(subscriber1)
+
+    subscriber2_call_count = 0
+
+    def subscriber2(value: int) -> None:
+        assert value == 5
+        nonlocal subscriber2_call_count
+        subscriber2_call_count += 1
+
+    topic.subscribe(subscriber2)
+
+    topic.publish(5)
+
+    assert subscriber1_call_count == 1
+    assert subscriber2_call_count == 1
+
+
+def test_publish_no_subscribers() -> None:
+    """Tests that no errors occur when publishing to a topic with no
+    subscribers.
+    """
+    topic = Topic[int]()
+    topic.publish(8)
+
+
+def test_subscribe_as_event() -> None:
+    """Tests that event subscribers are notified of a publish"""
+    topic = Topic[str]()
+
+    event = topic.subscribe_as_event()
+    assert not event.is_set(), "The event was set before a publish"
+
+    # Tests that the event subscriber receives the publish
+    topic.publish("Hello event!", "You should be set")
+    assert event.is_set(), "The event was not set after a publish"
+    event.wait_and_clear()
+
+    assert not event.is_set()
+
+
+def test_gc_subscribe_as_event_removed() -> None:
+    """Tests that the event subscription is removed when the event is garbage
+    collected.
+    """
+    topic = Topic[str]()
+
+    event = topic.subscribe_as_event()
+    assert len(topic._subscribers) == 1
+
+    del event
+    gc.collect()
+
+    assert len(topic._subscribers) == 1
+    topic.publish("What is good")
+    assert len(topic._subscribers) == 0
+
+
+def test_multiple_subscribe_as_event() -> None:
+    """Tests that all subscribed events are notified with a single publish"""
+    topic = Topic[str]()
+
+    event1 = topic.subscribe_as_event()
+    event2 = topic.subscribe_as_event()
+
+    topic.publish("What's up, my events!")
+
+    assert event1.is_set()
+    assert event2.is_set()
diff --git a/pkgs/node_helpers/node_helpers_test/unit/robust_rpc/test_errors.py b/pkgs/node_helpers/node_helpers_test/unit/robust_rpc/test_errors.py
new file mode 100644
index 0000000..c6f6c0d
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers_test/unit/robust_rpc/test_errors.py
@@ -0,0 +1,57 @@
+import pytest
+from node_helpers.robust_rpc import RobustRPCException
+
+
+class CoolError(Exception):
+    pass
+
+
+def test_caching() -> None:
+    assert RobustRPCException.like("CoolError") is RobustRPCException.like(CoolError)
+    assert CoolError is not RobustRPCException.like(CoolError)  # type: ignore
+    assert RobustRPCException.like("CoolError").__name__ == CoolError.__name__
+
+
+def test_routing_exceptions_works() -> None:
+    """This test is a sanity check of how RobustRPCExceptions are supposed to work"""
+    try:
+        raise RobustRPCException.like("CoolError")(
+            error_name="CoolError", error_description="", message=None
+        )
+    except RobustRPCException.like("NotCoolError"):
+        raise
+    except RobustRPCException.like(CoolError):
+        routed_correctly = True
+    except RobustRPCException:
+        raise
+
+    assert routed_correctly
+
+
+def test_routing_accepts_robust_rpc_exception() -> None:
+    """Test that the 'like' function returns an object that subclasses
RobustRPCException + """ + try: + raise RobustRPCException.like("CoolError")( + error_name="CoolError", error_description="", message=None + ) + except RobustRPCException.like("NotCoolError"): + raise + except RobustRPCException: + routed_correctly = True + + assert routed_correctly + + +def test_routing_falls_through() -> None: + """A simple sanity check that one exception doesn't get matched with another of a + different name""" + + with pytest.raises(RobustRPCException.like("CoolError")): + try: + raise RobustRPCException.like("CoolError")( + error_name="CoolError", error_description="", message=None + ) + except RobustRPCException.like("NotCoolError"): + pass diff --git a/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_geometry.py b/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_geometry.py new file mode 100644 index 0000000..02d224b --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_geometry.py @@ -0,0 +1,93 @@ +import node_helpers.ros2_numpy as rnp +import numpy as np +import pytest +import tf_transformations as transformations +from geometry_msgs.msg import Point, Pose, Quaternion, Transform, Vector3 + + +def test_point() -> None: + p = Point(x=1.0, y=2.0, z=3.0) + + p_arr = rnp.numpify(p) + np.testing.assert_array_equal(p_arr, [1, 2, 3]) + + p_arrh = rnp.numpify(p, hom=True) + np.testing.assert_array_equal(p_arrh, [1, 2, 3, 1]) + + assert p == rnp.msgify(Point, p_arr) + assert p == rnp.msgify(Point, p_arrh) + assert p == rnp.msgify(Point, p_arrh * 2) + + +def test_vector3() -> None: + v = Vector3(x=1.0, y=2.0, z=3.0) + + v_arr = rnp.numpify(v) + np.testing.assert_array_equal(v_arr, [1, 2, 3]) + + v_arrh = rnp.numpify(v, hom=True) + np.testing.assert_array_equal(v_arrh, [1, 2, 3, 0]) + + assert v == rnp.msgify(Vector3, v_arr) + assert v == rnp.msgify(Vector3, v_arrh) + + with pytest.raises(AssertionError): + rnp.msgify(Vector3, np.array([0, 0, 0, 1])) + + +def test_transform() -> None: + t = Transform( + translation=Vector3(x=1.0, y=2.0, z=3.0), + rotation=Quaternion( + **dict( + zip( + ["x", "y", "z", "w"], + transformations.quaternion_from_euler(np.pi, 0, 0), + strict=False, + ) + ) + ), + ) + + t_mat = rnp.numpify(t) + + np.testing.assert_allclose(t_mat.dot([0, 0, 1, 1]), [1.0, 2.0, 2.0, 1.0]) + + msg = rnp.msgify(Transform, t_mat) + + np.testing.assert_allclose(msg.translation.x, t.translation.x) + np.testing.assert_allclose(msg.translation.y, t.translation.y) + np.testing.assert_allclose(msg.translation.z, t.translation.z) + np.testing.assert_allclose(msg.rotation.x, t.rotation.x) + np.testing.assert_allclose(msg.rotation.y, t.rotation.y) + np.testing.assert_allclose(msg.rotation.z, t.rotation.z) + np.testing.assert_allclose(msg.rotation.w, t.rotation.w) + + +def test_pose() -> None: + t = Pose( + position=Point(x=1.0, y=2.0, z=3.0), + orientation=Quaternion( + **dict( + zip( + ["x", "y", "z", "w"], + transformations.quaternion_from_euler(np.pi, 0, 0), + strict=False, + ) + ) + ), + ) + + t_mat = rnp.numpify(t) + + np.testing.assert_allclose(t_mat.dot([0, 0, 1, 1]), [1.0, 2.0, 2.0, 1.0]) + + msg = rnp.msgify(Pose, t_mat) + + np.testing.assert_allclose(msg.position.x, t.position.x) + np.testing.assert_allclose(msg.position.y, t.position.y) + np.testing.assert_allclose(msg.position.z, t.position.z) + np.testing.assert_allclose(msg.orientation.x, t.orientation.x) + np.testing.assert_allclose(msg.orientation.y, t.orientation.y) + np.testing.assert_allclose(msg.orientation.z, t.orientation.z) + 
    np.testing.assert_allclose(msg.orientation.w, t.orientation.w)
diff --git a/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_images.py b/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_images.py
new file mode 100644
index 0000000..824b018
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_images.py
@@ -0,0 +1,55 @@
+import node_helpers.ros2_numpy as rnp
+import numpy as np
+import pytest
+from sensor_msgs.msg import Image
+
+
+def test_roundtrip_rgb8() -> None:
+    arr = np.random.randint(0, 256, size=(240, 360, 3)).astype(np.uint8)
+    msg = rnp.msgify(Image, arr, encoding="rgb8")
+    arr2 = rnp.numpify(msg)
+
+    np.testing.assert_equal(arr, arr2)
+
+
+def test_roundtrip_mono() -> None:
+    arr = np.random.randint(0, 256, size=(240, 360)).astype(np.uint8)
+    msg = rnp.msgify(Image, arr, encoding="mono8")
+    arr2 = rnp.numpify(msg)
+
+    np.testing.assert_equal(arr, arr2)
+
+
+def test_roundtrip_big_endian() -> None:
+    arr = np.random.randint(0, 256, size=(240, 360)).astype(">u2")
+    msg = rnp.msgify(Image, arr, encoding="mono16")
+    assert msg.is_bigendian
+    arr2 = rnp.numpify(msg)
+
+    np.testing.assert_equal(arr, arr2)
+
+
+def test_roundtrip_little_endian() -> None:
+    arr = np.random.randint(0, 256, size=(240, 360)).astype("<u2")
+    msg = rnp.msgify(Image, arr, encoding="mono16")
+    assert not msg.is_bigendian
+    arr2 = rnp.numpify(msg)
+
+    np.testing.assert_equal(arr, arr2)
+
+
+def test_bad_encodings() -> None:
+    mono_arr = np.random.randint(0, 256, size=(240, 360)).astype(np.uint8)
+    mono_arrf = np.random.randint(0, 256, size=(240, 360)).astype(np.float32)
+    rgb_arr = np.random.randint(0, 256, size=(240, 360, 3)).astype(np.uint8)
+    rgb_arrf = np.random.randint(0, 256, size=(240, 360, 3)).astype(np.float32)
+
+    with pytest.raises(TypeError):
+        rnp.msgify(Image, rgb_arr, encoding="mono8")
+    with pytest.raises(TypeError):
+        rnp.msgify(Image, mono_arrf, encoding="mono8")
+
+    with pytest.raises(TypeError):
+        rnp.msgify(Image, rgb_arrf, encoding="rgb8")
+    with pytest.raises(TypeError):
+        rnp.msgify(Image, mono_arr, encoding="rgb8")
diff --git a/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_laserscan.py b/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_laserscan.py
new file mode 100644
index 0000000..2f75f1c
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_laserscan.py
@@ -0,0 +1,53 @@
+import node_helpers.ros2_numpy as rnp
+import numpy as np
+from rclpy.clock import Clock
+from sensor_msgs.msg import LaserScan
+
+
+def test_to_and_from_laserscan() -> None:
+    # Create a dummy LaserScan message
+    scan = LaserScan()
+    scan.header.stamp = Clock().now().to_msg()
+    scan.header.frame_id = "lidar_2d"
+    scan.range_min = 0.01
+    scan.range_max = 200.0
+    scan.angle_increment = np.radians(0.1)
+    scan.angle_min = -np.pi
+    scan.angle_max = np.pi - scan.angle_increment
+    scan.scan_time = 0.0
+    scan.time_increment = 0.0
+    scan.ranges = np.full(3600, 10.0, dtype="f4").tolist()
+    scan.intensities = np.full(3600, 5, dtype="f4").tolist()
+    laserscan_array = rnp.numpify(
+        scan,
+        remove_invalid_ranges=False,
+        include_ranges_and_intensities=True,
+    )
+
+    assert laserscan_array.shape[0] == 3600
+    np.testing.assert_array_equal(laserscan_array["ranges"], np.array(scan.ranges))
+    np.testing.assert_array_equal(
+        laserscan_array["intensities"], np.array(scan.intensities)
+    )
+
+    laserscan_array_without_ranges_and_intensities = rnp.numpify(scan)
+
+    laserscan_msg = rnp.msgify(
+        LaserScan,
+        laserscan_array_without_ranges_and_intensities,
+        scan.header,
+        scan.scan_time,
+        scan.time_increment,
+    )
+
+    assert np.isclose(scan.angle_min, laserscan_msg.angle_min)
+    assert np.isclose(laserscan_msg.angle_increment, scan.angle_increment)
+    assert np.isclose(laserscan_msg.angle_max, scan.angle_max)
+    assert np.isclose(laserscan_msg.scan_time, scan.scan_time)
+    assert np.isclose(laserscan_msg.time_increment, scan.time_increment)
+
+    assert laserscan_msg.header.frame_id == scan.header.frame_id
+
+    assert len(laserscan_msg.ranges) == 3600
+    assert len(laserscan_msg.intensities) == 3600
+    np.testing.assert_array_equal(laserscan_array["ranges"], np.array(scan.ranges))
diff --git a/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_occupancygrids.py b/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_occupancygrids.py
new file mode 100644
index 0000000..4e5b4e4
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_occupancygrids.py
@@ -0,0 +1,33 @@
+import node_helpers.ros2_numpy as rnp
+import numpy as np
+from nav_msgs.msg import MapMetaData, OccupancyGrid
+from rclpy.serialization import serialize_message
+
+
+def test_masking() -> None:
+    data = -np.ones((30, 30), np.int8)
+    data[10:20, 10:20] = 100
+
+    msg = rnp.msgify(OccupancyGrid, data)
+
+    data_out = rnp.numpify(msg)
+
+    assert data_out[5, 5] is np.ma.masked
+    np.testing.assert_equal(data_out[10:20, 10:20], 100)
+
+
+def test_serialization() -> None:
+    msg = OccupancyGrid(
+        info=MapMetaData(width=3, height=3), data=[0, 0, 0, 0, -1, 0, 0, 0, 0]
+    )
+
+    data = rnp.numpify(msg)
+    assert data[1, 1] is np.ma.masked
+    msg2 = rnp.msgify(OccupancyGrid, data)
+
+    assert msg.info == msg2.info
+
+    msg_ser = serialize_message(msg)
+    msg2_ser = serialize_message(msg2)
+
+    assert msg_ser == msg2_ser, "Message serialization survives round-trip"
diff --git a/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_pointclouds.py b/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_pointclouds.py
new file mode 100644
index 0000000..537b140
--- /dev/null
+++ b/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_pointclouds.py
@@ -0,0 +1,82 @@
+import node_helpers.ros2_numpy as rnp
+import numpy as np
+import numpy.typing as npt
+from sensor_msgs.msg import PointCloud2, PointField
+
+
+def make_array(npoints: int) -> npt.NDArray[np.float64]:
+    points_arr = np.zeros(
+        (npoints,),
+        dtype=[
+            ("x", np.float32),
+            ("y", np.float32),
+            ("z", np.float32),
+            ("r", np.uint8),
+            ("g", np.uint8),
+            ("b", np.uint8),
+        ],
+    )
+    points_arr["x"] = np.random.random((npoints,))
+    points_arr["y"] = np.random.random((npoints,))
+    points_arr["z"] = np.random.random((npoints,))
+    points_arr["r"] = np.floor(np.random.random((npoints,)) * 255)
+    points_arr["g"] = 0
+    points_arr["b"] = 255
+
+    return points_arr
+
+
+def test_convert_dtype() -> None:
+    fields = [
+        PointField(name="x", offset=0, count=1, datatype=PointField.FLOAT32),
+        PointField(name="y", offset=4, count=1, datatype=PointField.FLOAT32),
+    ]
+    dtype = np.dtype([("x", np.float32), ("y", np.float32)])
+    conv_fields = rnp.msgify(PointField, dtype, plural=True)
+    assert fields == conv_fields, "dtype->Pointfield failed with simple values"
+
+    conv_dtype = rnp.numpify(fields, point_step=8)
+    assert dtype == conv_dtype, "Pointfield->dtype failed with simple values"
+
+
+def test_convert_dtype_inner() -> None:
+    fields = [
+        PointField(name="x", offset=0, count=1, datatype=PointField.FLOAT32),
+        PointField(name="y", offset=4, count=1, datatype=PointField.FLOAT32),
+        PointField(name="vectors", offset=8, count=3, datatype=PointField.FLOAT32),
+    ]
+
+    dtype = np.dtype(
+        [("x", np.float32), ("y", np.float32), ("vectors", np.float32, (3,))]
+    )
+
+    conv_fields = rnp.msgify(PointField,
dtype, plural=True) + assert fields == conv_fields, "dtype->Pointfield with inner dimensions" + + conv_dtype = rnp.numpify(fields, point_step=8) + assert dtype == conv_dtype, "Pointfield->dtype with inner dimensions" + + +def test_roundtrip() -> None: + points_arr = make_array(100) + cloud_msg = rnp.msgify(PointCloud2, points_arr) + new_points_arr = rnp.numpify(cloud_msg) + + np.testing.assert_equal(points_arr, new_points_arr) + + +def test_roundtrip_numpy() -> None: + points_arr = make_array(100) + cloud_msg = rnp.msgify(PointCloud2, points_arr) + new_points_arr = rnp.numpify(cloud_msg) + + np.testing.assert_equal(points_arr, new_points_arr) + + +def test_roundtrip_zero_points() -> None: + """Test to make sure zero point arrays don't raise memoryview.cast(*) errors""" + points_arr = make_array(0) + cloud_msg = rnp.msgify(PointCloud2, points_arr) + new_points_arr = rnp.numpify(cloud_msg) + + np.testing.assert_equal(points_arr, new_points_arr) diff --git a/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_quat.py b/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_quat.py new file mode 100644 index 0000000..1f69544 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/ros2_numpy/test_quat.py @@ -0,0 +1,14 @@ +import geometry_msgs +import node_helpers.ros2_numpy as rnp +import numpy as np +import tf_transformations as trans + + +def test_representation() -> None: + q = trans.quaternion_from_euler(0.0, 0.0, 0.0) + assert np.allclose(q, np.array([0.0, 0.0, 0.0, 1.0])) + + +def test_identity_transform() -> None: + h = rnp.numpify(geometry_msgs.msg.Transform()) + assert np.allclose(h, np.eye(4)) diff --git a/pkgs/node_helpers/node_helpers_test/unit/tf/test_timestamps.py b/pkgs/node_helpers/node_helpers_test/unit/tf/test_timestamps.py new file mode 100644 index 0000000..7d06293 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/tf/test_timestamps.py @@ -0,0 +1,37 @@ +import math + +from builtin_interfaces.msg import Time +from node_helpers.tf import timestamps + + +def test_timestamp_conversions() -> None: + """Test the timestamp conversion functions""" + # A python time.time() timestamp + python_timestamp = 1702493856.576949412 + + ros2_timestamp = timestamps.unix_timestamp_to_ros(python_timestamp) + assert ros2_timestamp.sec == 1702493856 + assert math.isclose(ros2_timestamp.nanosec, 576949412, abs_tol=100) + + python_timestamp_reconverted = timestamps.ros_stamp_to_unix_timestamp( + ros2_timestamp + ) + assert python_timestamp_reconverted == python_timestamp + + +def test_is_newer() -> None: + """Test the is_newer function""" + time_a = Time(sec=1702493856, nanosec=576949412) + time_b = Time(sec=1702493855, nanosec=576949412) + + assert timestamps.is_newer(time_a, time_b) is True + assert timestamps.is_newer(time_b, time_a) is False + + +def test_is_older() -> None: + """Test the is_older function""" + time_a = Time(sec=1702493856, nanosec=576949412) + time_b = Time(sec=1702493855, nanosec=576949412) + + assert timestamps.is_older(time_a, time_b) is False + assert timestamps.is_older(time_b, time_a) is True diff --git a/pkgs/node_helpers/node_helpers_test/unit/timing/test_timeout.py b/pkgs/node_helpers/node_helpers_test/unit/timing/test_timeout.py new file mode 100644 index 0000000..58dc858 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/timing/test_timeout.py @@ -0,0 +1,53 @@ +import pytest +from node_helpers.timing import Timeout, Timer + + +def test_timeout_happy_path() -> None: + """Test when a timer doesn't time out""" + timeout = 
Timeout(seconds=999) + + assert timeout, "Non timed out timers should return True!" + + +def test_timeout_unhappy_path() -> None: + """Test when a timer times out""" + + timeout = Timeout(seconds=0.015, raise_error=True) + timer = Timer() + with pytest.raises(TimeoutError), timer: + while timeout: + pass + assert 0.02 > timer.elapsed > 0.01 + + +def test_timeout_no_error() -> None: + """Test when a timer times out, but an error should not be raised""" + + timeout = Timeout(seconds=0.015, raise_error=False) + timer = Timer() + with timer: + while timeout: + pass + assert 0.02 > timer.elapsed > 0.01 + assert not timeout + + +def test_timeout_reset() -> None: + """Test resetting with a different timeout""" + + timeout = Timeout(seconds=0.015) + timer = Timer() + with timer: + while timeout: + pass + assert 0.02 > timer.elapsed > 0.01 + assert not timeout + + timeout.reset_seconds(0.01) + + timer = Timer() + with timer: + while timeout: + pass + assert not timeout + assert 0.015 > timer.elapsed > 0.005 diff --git a/pkgs/node_helpers/node_helpers_test/unit/timing/test_timer.py b/pkgs/node_helpers/node_helpers_test/unit/timing/test_timer.py new file mode 100644 index 0000000..e587d28 --- /dev/null +++ b/pkgs/node_helpers/node_helpers_test/unit/timing/test_timer.py @@ -0,0 +1,110 @@ +import pytest +from node_helpers.timing import Timer + + +def test_context_manager() -> None: + """Test the context manager of the timer utility works as expected""" + n_samples = 5 + timer = Timer(samples=n_samples) + + for i in range(n_samples): + with timer: + pass + assert len(timer._samples) == i + 1 + + with timer: + pass + + # Verify the timer doesn't keep samples over the specified amount + assert len(timer._samples) == n_samples + assert timer.fps > 0 + assert timer.elapsed > 0 + + +def test_decorator() -> None: + """Test the decorator functionality of the timer utility works as expected""" + timer = Timer(samples=3) + + @timer + def my_cool_func() -> int: + return 3 + + retval = my_cool_func() + assert retval == 3 + assert len(timer._samples) == 1 + + my_cool_func() + assert len(timer._samples) == 2 + + +def test_printing_doesnt_cause_zero_division_error() -> None: + timer = Timer(samples=100) + assert len(timer._samples) == 0 + assert repr(timer) == f"Timer({Timer._NO_SAMPLES_MSG})" + + with timer: + pass + + assert len(timer._samples) == 1 + assert repr(timer) != f"Timer({Timer._NO_SAMPLES_MSG})" + + +def test_child_gets_parameters_passed() -> None: + timer = Timer(samples=32) + child_name = "cool-functionality" + assert len(timer._children) == 0 + + # Create the child + with timer: + timer.child(child_name) + + assert len(timer._children) == 1 + assert timer._children[child_name].name == child_name + assert timer._children[child_name]._num_samples == timer._num_samples + + # Add samples to the child + with timer, timer.child(child_name): + pass + + assert len(timer._children[child_name]._samples) == 1 + + +def test_report_generation() -> None: + name = "cool_name" + timer = Timer(samples=42, name=name) + assert name.title() in str(timer) + assert Timer._NO_SAMPLES_MSG in str(timer) + assert Timer._REPORT_INDENT not in str(timer) + + with timer: + pass + + assert name.title() in str(timer) + assert Timer._REPORT_INDENT not in str(timer) + + # Test children + with timer, timer.child("thing"): + pass + assert str(timer).count(Timer._REPORT_INDENT) == 1 + + +def test_enter_child_outside_of_context() -> None: + timer = Timer() + timer_name = "cool-child" + + with pytest.raises(RuntimeError): + 
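        # Requesting a child timer is only valid inside the parent's `with timer:`
+        # context, so this call is expected to raise.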
+        timer.child(timer_name)
+
+    assert len(timer._children) == 0
+
+
+def test_zero_samples_behavior() -> None:
+    """The timer shouldn't throw errors when 'elapsed' and 'fps' are called"""
+    timer = Timer()
+
+    assert timer.elapsed == 0.0
+    assert timer.total_elapsed == 0.0
+    assert timer.fps == 0.0
+    assert isinstance(timer.elapsed, float)
+    assert isinstance(timer.total_elapsed, float)
+    assert isinstance(timer.fps, float)
diff --git a/pkgs/node_helpers/package.xml b/pkgs/node_helpers/package.xml
index c401f80..3207f4e 100644
--- a/pkgs/node_helpers/package.xml
+++ b/pkgs/node_helpers/package.xml
@@ -2,19 +2,26 @@
   <name>node_helpers</name>
-  <version>1.0.0</version>
-  <description>An opinionated ROS2 framework that minimizes boilerplate while maximizing reliability. Features intuitive APIs for parameter management, action handling, and error-resilient RPC. Designed by Urban Machine for safe and scalable robotics.</description>
+  <version>0.5.0</version>
+  <description>An opinionated ROS2 framework that minimizes boilerplate while
+    maximizing reliability. Features intuitive APIs for parameter management, action
+    handling, and error-resilient RPC. Designed by Urban Machine for safe and
+    scalable robotics.
+  </description>
   <maintainer>urbanmachine</maintainer>
   <license>MIT</license>

   <test_depend>python3-pytest</test_depend>
   <test_depend>python3-pytest-cov</test_depend>
-
-
-  <exec_depend>robot_state_publisher</exec_depend>
+  <exec_depend>rviz2</exec_depend>
-  <exec_depend>joint_state_publisher</exec_depend>
+
+
+  <exec_depend>geometry_msgs</exec_depend>
+  <exec_depend>nav_msgs</exec_depend>
+  <exec_depend>sensor_msgs</exec_depend>
+  <exec_depend>tf_transformations</exec_depend>

   <export>
     <build_type>ament_python</build_type>
   </export>
diff --git a/pkgs/node_helpers/poetry.lock b/pkgs/node_helpers/poetry.lock
index f89b8fc..4648f75 100644
--- a/pkgs/node_helpers/poetry.lock
+++ b/pkgs/node_helpers/poetry.lock
@@ -1,7 +1,299 @@
 # This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
-package = []
+
+[[package]]
+name = "annotated-types"
+version = "0.7.0"
+description = "Reusable constraint types to use with typing.Annotated"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"},
+    {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
+]
+
+[[package]]
+name = "numpy"
+version = "2.1.3"
+description = "Fundamental package for array computing in Python"
+optional = false
+python-versions = ">=3.10"
+files = [
+    {file = "numpy-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c894b4305373b9c5576d7a12b473702afdf48ce5369c074ba304cc5ad8730dff"},
+    {file = "numpy-2.1.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b47fbb433d3260adcd51eb54f92a2ffbc90a4595f8970ee00e064c644ac788f5"},
+    {file = "numpy-2.1.3-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:825656d0743699c529c5943554d223c021ff0494ff1442152ce887ef4f7561a1"},
+    {file = "numpy-2.1.3-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:6a4825252fcc430a182ac4dee5a505053d262c807f8a924603d411f6718b88fd"},
+    {file = "numpy-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e711e02f49e176a01d0349d82cb5f05ba4db7d5e7e0defd026328e5cfb3226d3"},
+    {file = "numpy-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78574ac2d1a4a02421f25da9559850d59457bac82f2b8d7a44fe83a64f770098"},
+    {file = "numpy-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c7662f0e3673fe4e832fe07b65c50342ea27d989f92c80355658c7f888fcc83c"},
+    {file = "numpy-2.1.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4"},
+    {file = "numpy-2.1.3-cp310-cp310-win32.whl", hash = 
"sha256:72dcc4a35a8515d83e76b58fdf8113a5c969ccd505c8a946759b24e3182d1f23"}, + {file = "numpy-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:ecc76a9ba2911d8d37ac01de72834d8849e55473457558e12995f4cd53e778e0"}, + {file = "numpy-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4d1167c53b93f1f5d8a139a742b3c6f4d429b54e74e6b57d0eff40045187b15d"}, + {file = "numpy-2.1.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c80e4a09b3d95b4e1cac08643f1152fa71a0a821a2d4277334c88d54b2219a41"}, + {file = "numpy-2.1.3-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:576a1c1d25e9e02ed7fa5477f30a127fe56debd53b8d2c89d5578f9857d03ca9"}, + {file = "numpy-2.1.3-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:973faafebaae4c0aaa1a1ca1ce02434554d67e628b8d805e61f874b84e136b09"}, + {file = "numpy-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:762479be47a4863e261a840e8e01608d124ee1361e48b96916f38b119cfda04a"}, + {file = "numpy-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc6f24b3d1ecc1eebfbf5d6051faa49af40b03be1aaa781ebdadcbc090b4539b"}, + {file = "numpy-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:17ee83a1f4fef3c94d16dc1802b998668b5419362c8a4f4e8a491de1b41cc3ee"}, + {file = "numpy-2.1.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:15cb89f39fa6d0bdfb600ea24b250e5f1a3df23f901f51c8debaa6a5d122b2f0"}, + {file = "numpy-2.1.3-cp311-cp311-win32.whl", hash = "sha256:d9beb777a78c331580705326d2367488d5bc473b49a9bc3036c154832520aca9"}, + {file = "numpy-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:d89dd2b6da69c4fff5e39c28a382199ddedc3a5be5390115608345dec660b9e2"}, + {file = "numpy-2.1.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f55ba01150f52b1027829b50d70ef1dafd9821ea82905b63936668403c3b471e"}, + {file = "numpy-2.1.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:13138eadd4f4da03074851a698ffa7e405f41a0845a6b1ad135b81596e4e9958"}, + {file = "numpy-2.1.3-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:a6b46587b14b888e95e4a24d7b13ae91fa22386c199ee7b418f449032b2fa3b8"}, + {file = "numpy-2.1.3-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:0fa14563cc46422e99daef53d725d0c326e99e468a9320a240affffe87852564"}, + {file = "numpy-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8637dcd2caa676e475503d1f8fdb327bc495554e10838019651b76d17b98e512"}, + {file = "numpy-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2312b2aa89e1f43ecea6da6ea9a810d06aae08321609d8dc0d0eda6d946a541b"}, + {file = "numpy-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a38c19106902bb19351b83802531fea19dee18e5b37b36454f27f11ff956f7fc"}, + {file = "numpy-2.1.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:02135ade8b8a84011cbb67dc44e07c58f28575cf9ecf8ab304e51c05528c19f0"}, + {file = "numpy-2.1.3-cp312-cp312-win32.whl", hash = "sha256:e6988e90fcf617da2b5c78902fe8e668361b43b4fe26dbf2d7b0f8034d4cafb9"}, + {file = "numpy-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:0d30c543f02e84e92c4b1f415b7c6b5326cbe45ee7882b6b77db7195fb971e3a"}, + {file = "numpy-2.1.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:96fe52fcdb9345b7cd82ecd34547fca4321f7656d500eca497eb7ea5a926692f"}, + {file = "numpy-2.1.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f653490b33e9c3a4c1c01d41bc2aef08f9475af51146e4a7710c450cf9761598"}, + {file = "numpy-2.1.3-cp313-cp313-macosx_14_0_arm64.whl", hash = 
"sha256:dc258a761a16daa791081d026f0ed4399b582712e6fc887a95af09df10c5ca57"}, + {file = "numpy-2.1.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:016d0f6f5e77b0f0d45d77387ffa4bb89816b57c835580c3ce8e099ef830befe"}, + {file = "numpy-2.1.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c181ba05ce8299c7aa3125c27b9c2167bca4a4445b7ce73d5febc411ca692e43"}, + {file = "numpy-2.1.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5641516794ca9e5f8a4d17bb45446998c6554704d888f86df9b200e66bdcce56"}, + {file = "numpy-2.1.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ea4dedd6e394a9c180b33c2c872b92f7ce0f8e7ad93e9585312b0c5a04777a4a"}, + {file = "numpy-2.1.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b0df3635b9c8ef48bd3be5f862cf71b0a4716fa0e702155c45067c6b711ddcef"}, + {file = "numpy-2.1.3-cp313-cp313-win32.whl", hash = "sha256:50ca6aba6e163363f132b5c101ba078b8cbd3fa92c7865fd7d4d62d9779ac29f"}, + {file = "numpy-2.1.3-cp313-cp313-win_amd64.whl", hash = "sha256:747641635d3d44bcb380d950679462fae44f54b131be347d5ec2bce47d3df9ed"}, + {file = "numpy-2.1.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:996bb9399059c5b82f76b53ff8bb686069c05acc94656bb259b1d63d04a9506f"}, + {file = "numpy-2.1.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:45966d859916ad02b779706bb43b954281db43e185015df6eb3323120188f9e4"}, + {file = "numpy-2.1.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:baed7e8d7481bfe0874b566850cb0b85243e982388b7b23348c6db2ee2b2ae8e"}, + {file = "numpy-2.1.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:a9f7f672a3388133335589cfca93ed468509cb7b93ba3105fce780d04a6576a0"}, + {file = "numpy-2.1.3-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7aac50327da5d208db2eec22eb11e491e3fe13d22653dce51b0f4109101b408"}, + {file = "numpy-2.1.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4394bc0dbd074b7f9b52024832d16e019decebf86caf909d94f6b3f77a8ee3b6"}, + {file = "numpy-2.1.3-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:50d18c4358a0a8a53f12a8ba9d772ab2d460321e6a93d6064fc22443d189853f"}, + {file = "numpy-2.1.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:14e253bd43fc6b37af4921b10f6add6925878a42a0c5fe83daee390bca80bc17"}, + {file = "numpy-2.1.3-cp313-cp313t-win32.whl", hash = "sha256:08788d27a5fd867a663f6fc753fd7c3ad7e92747efc73c53bca2f19f8bc06f48"}, + {file = "numpy-2.1.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2564fbdf2b99b3f815f2107c1bbc93e2de8ee655a69c261363a1172a79a257d4"}, + {file = "numpy-2.1.3-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:4f2015dfe437dfebbfce7c85c7b53d81ba49e71ba7eadbf1df40c915af75979f"}, + {file = "numpy-2.1.3-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:3522b0dfe983a575e6a9ab3a4a4dfe156c3e428468ff08ce582b9bb6bd1d71d4"}, + {file = "numpy-2.1.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c006b607a865b07cd981ccb218a04fc86b600411d83d6fc261357f1c0966755d"}, + {file = "numpy-2.1.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e14e26956e6f1696070788252dcdff11b4aca4c3e8bd166e0df1bb8f315a67cb"}, + {file = "numpy-2.1.3.tar.gz", hash = "sha256:aa08e04e08aaf974d4458def539dece0d28146d866a39da5639596f4921fd761"}, +] + +[[package]] +name = "pydantic" +version = "2.10.2" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic-2.10.2-py3-none-any.whl", hash = 
"sha256:cfb96e45951117c3024e6b67b25cdc33a3cb7b2fa62e239f7af1378358a1d99e"}, + {file = "pydantic-2.10.2.tar.gz", hash = "sha256:2bc2d7f17232e0841cbba4641e65ba1eb6fafb3a08de3a091ff3ce14a197c4fa"}, +] + +[package.dependencies] +annotated-types = ">=0.6.0" +pydantic-core = "2.27.1" +typing-extensions = ">=4.12.2" + +[package.extras] +email = ["email-validator (>=2.0.0)"] +timezone = ["tzdata"] + +[[package]] +name = "pydantic-core" +version = "2.27.1" +description = "Core functionality for Pydantic validation and serialization" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic_core-2.27.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:71a5e35c75c021aaf400ac048dacc855f000bdfed91614b4a726f7432f1f3d6a"}, + {file = "pydantic_core-2.27.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f82d068a2d6ecfc6e054726080af69a6764a10015467d7d7b9f66d6ed5afa23b"}, + {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:121ceb0e822f79163dd4699e4c54f5ad38b157084d97b34de8b232bcaad70278"}, + {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4603137322c18eaf2e06a4495f426aa8d8388940f3c457e7548145011bb68e05"}, + {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a33cd6ad9017bbeaa9ed78a2e0752c5e250eafb9534f308e7a5f7849b0b1bfb4"}, + {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15cc53a3179ba0fcefe1e3ae50beb2784dede4003ad2dfd24f81bba4b23a454f"}, + {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45d9c5eb9273aa50999ad6adc6be5e0ecea7e09dbd0d31bd0c65a55a2592ca08"}, + {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8bf7b66ce12a2ac52d16f776b31d16d91033150266eb796967a7e4621707e4f6"}, + {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:655d7dd86f26cb15ce8a431036f66ce0318648f8853d709b4167786ec2fa4807"}, + {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:5556470f1a2157031e676f776c2bc20acd34c1990ca5f7e56f1ebf938b9ab57c"}, + {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f69ed81ab24d5a3bd93861c8c4436f54afdf8e8cc421562b0c7504cf3be58206"}, + {file = "pydantic_core-2.27.1-cp310-none-win32.whl", hash = "sha256:f5a823165e6d04ccea61a9f0576f345f8ce40ed533013580e087bd4d7442b52c"}, + {file = "pydantic_core-2.27.1-cp310-none-win_amd64.whl", hash = "sha256:57866a76e0b3823e0b56692d1a0bf722bffb324839bb5b7226a7dbd6c9a40b17"}, + {file = "pydantic_core-2.27.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac3b20653bdbe160febbea8aa6c079d3df19310d50ac314911ed8cc4eb7f8cb8"}, + {file = "pydantic_core-2.27.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a5a8e19d7c707c4cadb8c18f5f60c843052ae83c20fa7d44f41594c644a1d330"}, + {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f7059ca8d64fea7f238994c97d91f75965216bcbe5f695bb44f354893f11d52"}, + {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bed0f8a0eeea9fb72937ba118f9db0cb7e90773462af7962d382445f3005e5a4"}, + {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a3cb37038123447cf0f3ea4c74751f6a9d7afef0eb71aa07bf5f652b5e6a132c"}, + {file = 
"pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84286494f6c5d05243456e04223d5a9417d7f443c3b76065e75001beb26f88de"}, + {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acc07b2cfc5b835444b44a9956846b578d27beeacd4b52e45489e93276241025"}, + {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4fefee876e07a6e9aad7a8c8c9f85b0cdbe7df52b8a9552307b09050f7512c7e"}, + {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:258c57abf1188926c774a4c94dd29237e77eda19462e5bb901d88adcab6af919"}, + {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:35c14ac45fcfdf7167ca76cc80b2001205a8d5d16d80524e13508371fb8cdd9c"}, + {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d1b26e1dff225c31897696cab7d4f0a315d4c0d9e8666dbffdb28216f3b17fdc"}, + {file = "pydantic_core-2.27.1-cp311-none-win32.whl", hash = "sha256:2cdf7d86886bc6982354862204ae3b2f7f96f21a3eb0ba5ca0ac42c7b38598b9"}, + {file = "pydantic_core-2.27.1-cp311-none-win_amd64.whl", hash = "sha256:3af385b0cee8df3746c3f406f38bcbfdc9041b5c2d5ce3e5fc6637256e60bbc5"}, + {file = "pydantic_core-2.27.1-cp311-none-win_arm64.whl", hash = "sha256:81f2ec23ddc1b476ff96563f2e8d723830b06dceae348ce02914a37cb4e74b89"}, + {file = "pydantic_core-2.27.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9cbd94fc661d2bab2bc702cddd2d3370bbdcc4cd0f8f57488a81bcce90c7a54f"}, + {file = "pydantic_core-2.27.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5f8c4718cd44ec1580e180cb739713ecda2bdee1341084c1467802a417fe0f02"}, + {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15aae984e46de8d376df515f00450d1522077254ef6b7ce189b38ecee7c9677c"}, + {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ba5e3963344ff25fc8c40da90f44b0afca8cfd89d12964feb79ac1411a260ac"}, + {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:992cea5f4f3b29d6b4f7f1726ed8ee46c8331c6b4eed6db5b40134c6fe1768bb"}, + {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0325336f348dbee6550d129b1627cb8f5351a9dc91aad141ffb96d4937bd9529"}, + {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7597c07fbd11515f654d6ece3d0e4e5093edc30a436c63142d9a4b8e22f19c35"}, + {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3bbd5d8cc692616d5ef6fbbbd50dbec142c7e6ad9beb66b78a96e9c16729b089"}, + {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:dc61505e73298a84a2f317255fcc72b710b72980f3a1f670447a21efc88f8381"}, + {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:e1f735dc43da318cad19b4173dd1ffce1d84aafd6c9b782b3abc04a0d5a6f5bb"}, + {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f4e5658dbffe8843a0f12366a4c2d1c316dbe09bb4dfbdc9d2d9cd6031de8aae"}, + {file = "pydantic_core-2.27.1-cp312-none-win32.whl", hash = "sha256:672ebbe820bb37988c4d136eca2652ee114992d5d41c7e4858cdd90ea94ffe5c"}, + {file = "pydantic_core-2.27.1-cp312-none-win_amd64.whl", hash = "sha256:66ff044fd0bb1768688aecbe28b6190f6e799349221fb0de0e6f4048eca14c16"}, + {file = "pydantic_core-2.27.1-cp312-none-win_arm64.whl", hash 
= "sha256:9a3b0793b1bbfd4146304e23d90045f2a9b5fd5823aa682665fbdaf2a6c28f3e"}, + {file = "pydantic_core-2.27.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f216dbce0e60e4d03e0c4353c7023b202d95cbaeff12e5fd2e82ea0a66905073"}, + {file = "pydantic_core-2.27.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a2e02889071850bbfd36b56fd6bc98945e23670773bc7a76657e90e6b6603c08"}, + {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42b0e23f119b2b456d07ca91b307ae167cc3f6c846a7b169fca5326e32fdc6cf"}, + {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:764be71193f87d460a03f1f7385a82e226639732214b402f9aa61f0d025f0737"}, + {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c00666a3bd2f84920a4e94434f5974d7bbc57e461318d6bb34ce9cdbbc1f6b2"}, + {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ccaa88b24eebc0f849ce0a4d09e8a408ec5a94afff395eb69baf868f5183107"}, + {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c65af9088ac534313e1963443d0ec360bb2b9cba6c2909478d22c2e363d98a51"}, + {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:206b5cf6f0c513baffaeae7bd817717140770c74528f3e4c3e1cec7871ddd61a"}, + {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:062f60e512fc7fff8b8a9d680ff0ddaaef0193dba9fa83e679c0c5f5fbd018bc"}, + {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:a0697803ed7d4af5e4c1adf1670af078f8fcab7a86350e969f454daf598c4960"}, + {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:58ca98a950171f3151c603aeea9303ef6c235f692fe555e883591103da709b23"}, + {file = "pydantic_core-2.27.1-cp313-none-win32.whl", hash = "sha256:8065914ff79f7eab1599bd80406681f0ad08f8e47c880f17b416c9f8f7a26d05"}, + {file = "pydantic_core-2.27.1-cp313-none-win_amd64.whl", hash = "sha256:ba630d5e3db74c79300d9a5bdaaf6200172b107f263c98a0539eeecb857b2337"}, + {file = "pydantic_core-2.27.1-cp313-none-win_arm64.whl", hash = "sha256:45cf8588c066860b623cd11c4ba687f8d7175d5f7ef65f7129df8a394c502de5"}, + {file = "pydantic_core-2.27.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:5897bec80a09b4084aee23f9b73a9477a46c3304ad1d2d07acca19723fb1de62"}, + {file = "pydantic_core-2.27.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0165ab2914379bd56908c02294ed8405c252250668ebcb438a55494c69f44ab"}, + {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b9af86e1d8e4cfc82c2022bfaa6f459381a50b94a29e95dcdda8442d6d83864"}, + {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f6c8a66741c5f5447e047ab0ba7a1c61d1e95580d64bce852e3df1f895c4067"}, + {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a42d6a8156ff78981f8aa56eb6394114e0dedb217cf8b729f438f643608cbcd"}, + {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64c65f40b4cd8b0e049a8edde07e38b476da7e3aaebe63287c899d2cff253fa5"}, + {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdcf339322a3fae5cbd504edcefddd5a50d9ee00d968696846f089b4432cf78"}, + {file = 
"pydantic_core-2.27.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bf99c8404f008750c846cb4ac4667b798a9f7de673ff719d705d9b2d6de49c5f"}, + {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8f1edcea27918d748c7e5e4d917297b2a0ab80cad10f86631e488b7cddf76a36"}, + {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_armv7l.whl", hash = "sha256:159cac0a3d096f79ab6a44d77a961917219707e2a130739c64d4dd46281f5c2a"}, + {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:029d9757eb621cc6e1848fa0b0310310de7301057f623985698ed7ebb014391b"}, + {file = "pydantic_core-2.27.1-cp38-none-win32.whl", hash = "sha256:a28af0695a45f7060e6f9b7092558a928a28553366519f64083c63a44f70e618"}, + {file = "pydantic_core-2.27.1-cp38-none-win_amd64.whl", hash = "sha256:2d4567c850905d5eaaed2f7a404e61012a51caf288292e016360aa2b96ff38d4"}, + {file = "pydantic_core-2.27.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e9386266798d64eeb19dd3677051f5705bf873e98e15897ddb7d76f477131967"}, + {file = "pydantic_core-2.27.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4228b5b646caa73f119b1ae756216b59cc6e2267201c27d3912b592c5e323b60"}, + {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b3dfe500de26c52abe0477dde16192ac39c98f05bf2d80e76102d394bd13854"}, + {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:aee66be87825cdf72ac64cb03ad4c15ffef4143dbf5c113f64a5ff4f81477bf9"}, + {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b748c44bb9f53031c8cbc99a8a061bc181c1000c60a30f55393b6e9c45cc5bd"}, + {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ca038c7f6a0afd0b2448941b6ef9d5e1949e999f9e5517692eb6da58e9d44be"}, + {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e0bd57539da59a3e4671b90a502da9a28c72322a4f17866ba3ac63a82c4498e"}, + {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ac6c2c45c847bbf8f91930d88716a0fb924b51e0c6dad329b793d670ec5db792"}, + {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b94d4ba43739bbe8b0ce4262bcc3b7b9f31459ad120fb595627eaeb7f9b9ca01"}, + {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:00e6424f4b26fe82d44577b4c842d7df97c20be6439e8e685d0d715feceb9fb9"}, + {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:38de0a70160dd97540335b7ad3a74571b24f1dc3ed33f815f0880682e6880131"}, + {file = "pydantic_core-2.27.1-cp39-none-win32.whl", hash = "sha256:7ccebf51efc61634f6c2344da73e366c75e735960b5654b63d7e6f69a5885fa3"}, + {file = "pydantic_core-2.27.1-cp39-none-win_amd64.whl", hash = "sha256:a57847b090d7892f123726202b7daa20df6694cbd583b67a592e856bff603d6c"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3fa80ac2bd5856580e242dbc202db873c60a01b20309c8319b5c5986fbe53ce6"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d950caa237bb1954f1b8c9227b5065ba6875ac9771bb8ec790d956a699b78676"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e4216e64d203e39c62df627aa882f02a2438d18a5f21d7f721621f7a5d3611d"}, + {file = 
"pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02a3d637bd387c41d46b002f0e49c52642281edacd2740e5a42f7017feea3f2c"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:161c27ccce13b6b0c8689418da3885d3220ed2eae2ea5e9b2f7f3d48f1d52c27"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:19910754e4cc9c63bc1c7f6d73aa1cfee82f42007e407c0f413695c2f7ed777f"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:e173486019cc283dc9778315fa29a363579372fe67045e971e89b6365cc035ed"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:af52d26579b308921b73b956153066481f064875140ccd1dfd4e77db89dbb12f"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:981fb88516bd1ae8b0cbbd2034678a39dedc98752f264ac9bc5839d3923fa04c"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5fde892e6c697ce3e30c61b239330fc5d569a71fefd4eb6512fc6caec9dd9e2f"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:816f5aa087094099fff7edabb5e01cc370eb21aa1a1d44fe2d2aefdfb5599b31"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c10c309e18e443ddb108f0ef64e8729363adbfd92d6d57beec680f6261556f3"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98476c98b02c8e9b2eec76ac4156fd006628b1b2d0ef27e548ffa978393fd154"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c3027001c28434e7ca5a6e1e527487051136aa81803ac812be51802150d880dd"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7699b1df36a48169cdebda7ab5a2bac265204003f153b4bd17276153d997670a"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1c39b07d90be6b48968ddc8c19e7585052088fd7ec8d568bb31ff64c70ae3c97"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:46ccfe3032b3915586e469d4972973f893c0a2bb65669194a5bdea9bacc088c2"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:62ba45e21cf6571d7f716d903b5b7b6d2617e2d5d67c0923dc47b9d41369f840"}, + {file = "pydantic_core-2.27.1.tar.gz", hash = "sha256:62a763352879b84aa31058fc931884055fd75089cccbd9d58bb6afd01141b235"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + +[[package]] +name = "pyyaml" +version = "6.0.2" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, + {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = 
"sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash 
= "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, +] + +[[package]] +name = "transforms3d" +version = "0.4.2" +description = "Functions for 3D coordinate transformations" +optional = false +python-versions = ">=3.6" +files = [ + {file = "transforms3d-0.4.2.tar.gz", hash = "sha256:e8b5df30eaedbee556e81c6938e55aab5365894e47d0a17615d7db7fd2393680"}, +] + +[package.dependencies] +numpy = ">=1.15" + +[[package]] +name = "typing-extensions" +version = "4.12.2" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, +] [metadata] lock-version = "2.0" python-versions = ">=3.12.0,<4.0" -content-hash = "55077cf34bc451233d3044bf620c6a190f4462bc9d3fca046c9e4a6636be1781" +content-hash = "8ca7bb04e89b9682f0c90e7535ee55bb0dd91e5c08cf9d72ae4fd698e8bcc3e7" diff --git a/pkgs/node_helpers/pyproject.toml b/pkgs/node_helpers/pyproject.toml index a8e9d7a..78e1c63 100644 --- a/pkgs/node_helpers/pyproject.toml +++ b/pkgs/node_helpers/pyproject.toml @@ -7,14 +7,22 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.12.0,<4.0" +numpy = "^2.1.3" +pydantic = "^2.10.1" +pyyaml = "^6.0.2" + +# For ros2_numpy, this brings in a newer version of transforms3d that has been updated +# to work with numpy 2.0.0+ +transforms3d = "^0.4.2" [tool.poetry.scripts] -# Each entry here will create an executable which can be referenced in launchfiles -ExampleNode = "node_helpers.nodes.node_helpers_node:main" +interactive_transform_publisher = "node_helpers.nodes.interactive_transform_publisher:main" +sound_player = "node_helpers.nodes.sound_player:main" +placeholder = "node_helpers.nodes.placeholder:main" [tool.colcon-poetry-ros.data-files] "share/ament_index/resource_index/packages" = ["resource/node_helpers"] -"share/camera_drivers" = ["package.xml"] +"share/node_helpers" = ["package.xml"] [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/pkgs/node_helpers_msgs/CMakeLists.txt b/pkgs/node_helpers_msgs/CMakeLists.txt new file mode 100644 index 0000000..65b81db --- /dev/null +++ b/pkgs/node_helpers_msgs/CMakeLists.txt @@ -0,0 +1,58 @@ +cmake_minimum_required(VERSION 3.5) +project(node_helpers_msgs) + +# Default to C99 +if(NOT CMAKE_C_STANDARD) + set(CMAKE_C_STANDARD 99) +endif() + +# Default to C++14 +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 14) +endif() + +if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options(-Wall -Wextra -Wpedantic) +endif() + +find_package(ament_cmake REQUIRED) +find_package(rosidl_default_generators REQUIRED) +find_package(std_msgs REQUIRED) +find_package(geometry_msgs REQUIRED) + +###################################################################### +########## NOTE TO TEMPLATE USERS: When adding messages that require other messages, +########## add the following line below to ensure that the other messages are found. +########## Also make sure new messages are added to the rosidl_generate_interfaces list. 
+# find_package(my_required_msgs_package REQUIRED) +###################################################################### + + +rosidl_generate_interfaces( + ${PROJECT_NAME} + + ###################################################################### + ########## NOTE TO TEMPLATE USERS: Add new messages here, like so: + # msg/ANewMessageFile.msg + ###################################################################### + + "msg/BinaryReading.msg" + "msg/PlaySound.msg" + "msg/PromptOption.msg" + "msg/RangefinderReading.msg" + "msg/SensorExample.msg" + "msg/UserPrompt.msg" + + "action/RobustActionExample.action" + + "srv/ChoosePromptOption.srv" + "srv/RobustServiceExample.srv" + + ###################################################################### + ########## NOTE TO TEMPLATE USERS: Add your dependency packages here, like so: + # DEPENDENCIES my_required_msgs_package + ###################################################################### + DEPENDENCIES std_msgs geometry_msgs +) + +ament_package() diff --git a/pkgs/node_helpers_msgs/action/RobustActionExample.action b/pkgs/node_helpers_msgs/action/RobustActionExample.action new file mode 100644 index 0000000..79eee7e --- /dev/null +++ b/pkgs/node_helpers_msgs/action/RobustActionExample.action @@ -0,0 +1,7 @@ +# This message is for testing RobustRPC actions + +--- +string data +string error_name +string error_description +--- diff --git a/pkgs/node_helpers_msgs/msg/BinaryReading.msg b/pkgs/node_helpers_msgs/msg/BinaryReading.msg new file mode 100644 index 0000000..26a2d34 --- /dev/null +++ b/pkgs/node_helpers_msgs/msg/BinaryReading.msg @@ -0,0 +1,10 @@ +# This message is used to publish binary value sensor messages that follow the BaseSensor message structure + +std_msgs/Header header + +bool value + +# A few constants to help readability in code for different sensor use cases: +# For binary presence detection sensors +bool BLOCKED = true +bool UNBLOCKED = false \ No newline at end of file diff --git a/pkgs/node_helpers_msgs/msg/PlaySound.msg b/pkgs/node_helpers_msgs/msg/PlaySound.msg new file mode 100644 index 0000000..7f4e5b1 --- /dev/null +++ b/pkgs/node_helpers_msgs/msg/PlaySound.msg @@ -0,0 +1,4 @@ +# This message can be published to the SoundPlayer node, to, uhh, have sounds played! + +# The sound filename to play, which the SoundPlayer can find in its sound directory +string sound_filename diff --git a/pkgs/node_helpers_msgs/msg/PromptOption.msg b/pkgs/node_helpers_msgs/msg/PromptOption.msg new file mode 100644 index 0000000..555bd99 --- /dev/null +++ b/pkgs/node_helpers_msgs/msg/PromptOption.msg @@ -0,0 +1,6 @@ +# Used by `node_helpers.interaction` + +string name + +# Give a description for the option. +string description \ No newline at end of file diff --git a/pkgs/node_helpers_msgs/msg/RangefinderReading.msg b/pkgs/node_helpers_msgs/msg/RangefinderReading.msg new file mode 100644 index 0000000..d23efee --- /dev/null +++ b/pkgs/node_helpers_msgs/msg/RangefinderReading.msg @@ -0,0 +1,5 @@ +# This message is used to publish rangefinder readings, and follows the BaseSensor protocol +std_msgs/Header header + +# Vector value in meters +geometry_msgs/Vector3 value diff --git a/pkgs/node_helpers_msgs/msg/SensorExample.msg b/pkgs/node_helpers_msgs/msg/SensorExample.msg new file mode 100644 index 0000000..2170c70 --- /dev/null +++ b/pkgs/node_helpers_msgs/msg/SensorExample.msg @@ -0,0 +1,8 @@ +# This message is used as an example of a sensor message. It's used for tests as well. 
diff --git a/pkgs/node_helpers_msgs/msg/PromptOption.msg b/pkgs/node_helpers_msgs/msg/PromptOption.msg
new file mode 100644
index 0000000..555bd99
--- /dev/null
+++ b/pkgs/node_helpers_msgs/msg/PromptOption.msg
@@ -0,0 +1,6 @@
+# Used by `node_helpers.interaction`
+
+string name
+
+# Give a description for the option.
+string description
\ No newline at end of file
diff --git a/pkgs/node_helpers_msgs/msg/RangefinderReading.msg b/pkgs/node_helpers_msgs/msg/RangefinderReading.msg
new file mode 100644
index 0000000..d23efee
--- /dev/null
+++ b/pkgs/node_helpers_msgs/msg/RangefinderReading.msg
@@ -0,0 +1,5 @@
+# This message is used to publish rangefinder readings, and follows the BaseSensor protocol
+std_msgs/Header header
+
+# Vector value in meters
+geometry_msgs/Vector3 value
diff --git a/pkgs/node_helpers_msgs/msg/SensorExample.msg b/pkgs/node_helpers_msgs/msg/SensorExample.msg
new file mode 100644
index 0000000..2170c70
--- /dev/null
+++ b/pkgs/node_helpers_msgs/msg/SensorExample.msg
@@ -0,0 +1,8 @@
+# This message is used as an example of a sensor message. It's used for tests as well.
+
+# Each sensor has a 'header', where the 'frame_id' names the frame at whose
+# origin the sensor is located
+std_msgs/Header header
+
+# The value can be of any type, even made up of other message types.
+int64 value
\ No newline at end of file
diff --git a/pkgs/node_helpers_msgs/msg/UserPrompt.msg b/pkgs/node_helpers_msgs/msg/UserPrompt.msg
new file mode 100644
index 0000000..c8035fd
--- /dev/null
+++ b/pkgs/node_helpers_msgs/msg/UserPrompt.msg
@@ -0,0 +1,25 @@
+# A prompt that only provides options to a user on some arbitrary menu implementation
+# Used by `node_helpers.interaction`
+
+
+# Metadata schema:
+#   No members
+uint8 PROMPT_BASIC = 0
+# A prompt used to teleop a hardware component. When the dashboard receives a
+# prompt with this type, it will display menus to teleop the specified
+# components.
+#
+# Metadata schema:
+#   namespaces - An array of hardware module namespaces, specifying what
+#                hardware should be teleoped
+uint8 PROMPT_TELEOP = 1
+
+# Available options to transition to other states
+PromptOption[] options
+# Provides additional context about the prompt
+string help
+# The prompt type. Affects which additional options are presented to the user
+# by the dashboard
+uint8 type 0
+# A JSON object whose contents depend on the type of prompt being used
+string metadata
diff --git a/pkgs/node_helpers_msgs/package.xml b/pkgs/node_helpers_msgs/package.xml
new file mode 100644
index 0000000..6800c31
--- /dev/null
+++ b/pkgs/node_helpers_msgs/package.xml
@@ -0,0 +1,24 @@
+<?xml version="1.0"?>
+<package format="3">
+  <name>node_helpers_msgs</name>
+  <version>0.5.0</version>
+  <description>Defines messages used by the 'node_helpers' project.</description>
+  <maintainer email="">urbanmachine</maintainer>
+
+  <license>MIT</license>
+
+  <buildtool_depend>ament_cmake</buildtool_depend>
+
+  <build_depend>rosidl_default_generators</build_depend>
+  <depend>std_msgs</depend>
+  <depend>geometry_msgs</depend>
+
+  <exec_depend>rosidl_default_runtime</exec_depend>
+
+  <member_of_group>rosidl_interface_packages</member_of_group>
+
+  <export>
+    <build_type>ament_cmake</build_type>
+  </export>
+</package>
diff --git a/pkgs/node_helpers_msgs/srv/ChoosePromptOption.srv b/pkgs/node_helpers_msgs/srv/ChoosePromptOption.srv
new file mode 100644
index 0000000..b7da9cf
--- /dev/null
+++ b/pkgs/node_helpers_msgs/srv/ChoosePromptOption.srv
@@ -0,0 +1,9 @@
+# Choose a prompt option from a given menu
+# Used by `node_helpers.interaction`
+
+PromptOption option
+
+---
+
+string error_name
+string error_description
diff --git a/pkgs/node_helpers_msgs/srv/RobustServiceExample.srv b/pkgs/node_helpers_msgs/srv/RobustServiceExample.srv
new file mode 100644
index 0000000..13e0018
--- /dev/null
+++ b/pkgs/node_helpers_msgs/srv/RobustServiceExample.srv
@@ -0,0 +1,5 @@
+# This message is for testing RobustRPC services
+---
+string data
+string error_name
+string error_description
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 6ef3560..cc8b390 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "node_helpers"
-version = "1.0.0"
+version = "0.5.0"
 description = "An opinionated ROS2 framework that minimizes boilerplate while maximizing reliability. Features intuitive APIs for parameter management, action handling, and error-resilient RPC. Designed by Urban Machine for safe and scalable robotics."
 authors = ["urbanmachine "]
 license = "MIT"
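Putting the interaction messages together: a node publishes a `UserPrompt`, and the menu implementation answers by echoing one of the prompt's options back through the `ChoosePromptOption` service, whose response carries the same error-in-band fields as the RobustRPC examples. A hedged sketch of constructing such a prompt and checking a response (the option names, help text, namespace values, and JSON metadata contents are all illustrative, not taken from node_helpers):

```python
# Illustrative only: field values and metadata contents are guesses; only the
# message/srv structure comes from the definitions above.
import json

from node_helpers_msgs.msg import PromptOption, UserPrompt
from node_helpers_msgs.srv import ChoosePromptOption

# Build a teleop prompt whose metadata follows the schema documented in
# UserPrompt.msg: a JSON object with a "namespaces" array.
prompt = UserPrompt(
    options=[
        PromptOption(name="done", description="Finish teleoperation"),
        PromptOption(name="abort", description="Stop and report an error"),
    ],
    help="Jog the hardware into position, then choose an option.",
    type=UserPrompt.PROMPT_TELEOP,
    metadata=json.dumps({"namespaces": ["/gantry"]}),
)

# The menu implementation answers by sending a chosen option back:
request = ChoosePromptOption.Request(option=prompt.options[0])


def check_response(response: ChoosePromptOption.Response) -> None:
    """Apply the error-in-band convention implied by the srv's result fields."""
    if response.error_name:
        raise RuntimeError(f"{response.error_name}: {response.error_description}")
```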