diff --git a/pyproject.toml b/pyproject.toml index 5e229ba9e..4c31bb012 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,3 +15,10 @@ line-length = 88 [tool.isort] profile = "black" multi_line_output = 3 + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-rsxX -v --strict-markers --forked" +testpaths = [ + "tests", +] diff --git a/setup.cfg b/setup.cfg index d3aa7c991..b6eed46e7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,14 +53,13 @@ package_dir = python_requires = >=3.10 install_requires = coloredlogs - jax >= 0.4.1, <0.4.11 - jaxlib < 0.4.11 + jax >= 0.4.1 + jaxlib jaxlie jax_dataclasses >= 1.4.0 - ml-dtypes < 0.3.0 pptree rod - scipy + typing_extensions; python_version < "3.11" [options.packages.find] where = src @@ -71,13 +70,10 @@ style = isort testing = idyntree - pytest + pytest >= 6.0 + pytest-forked pytest-icdiff robot-descriptions all = %(style)s %(testing)s - -[tool:pytest] -addopts = -rsxX -v --strict-markers -testpaths = tests diff --git a/src/jaxsim/__init__.py b/src/jaxsim/__init__.py index 7ad51eb26..4064737a8 100644 --- a/src/jaxsim/__init__.py +++ b/src/jaxsim/__init__.py @@ -60,5 +60,7 @@ def _is_editable() -> bool: del _np_options del _is_editable -from . import high_level, logging, math, sixd +from . import high_level, logging, math, simulation, sixd +from .high_level.common import VelRepr +from .simulation.ode_integration import IntegratorType from .simulation.simulator import JaxSim diff --git a/src/jaxsim/high_level/__init__.py b/src/jaxsim/high_level/__init__.py index c70fae956..8d485d4a5 100644 --- a/src/jaxsim/high_level/__init__.py +++ b/src/jaxsim/high_level/__init__.py @@ -1 +1,2 @@ from . import common, joint, link, model +from .common import VelRepr diff --git a/src/jaxsim/high_level/joint.py b/src/jaxsim/high_level/joint.py index 16417169e..aef8b8360 100644 --- a/src/jaxsim/high_level/joint.py +++ b/src/jaxsim/high_level/joint.py @@ -1,73 +1,127 @@ import dataclasses -from typing import Any, Tuple +import functools +from typing import Any +import jax.numpy as jnp import jax_dataclasses from jax_dataclasses import Static import jaxsim.parsers import jaxsim.typing as jtp -from jaxsim.utils import JaxsimDataclass +from jaxsim.utils import Vmappable, not_tracing, oop @jax_dataclasses.pytree_dataclass -class Joint(JaxsimDataclass): +class Joint(Vmappable): """ - High-level class to operate on a single joint of a simulated model. + High-level class to operate in r/o on a single joint of a simulated model. """ joint_description: Static[jaxsim.parsers.descriptions.JointDescription] - _parent_model: Any = dataclasses.field(default=None, repr=False, compare=False) + _parent_model: Any = dataclasses.field( + default=None, repr=False, compare=False, hash=False + ) @property def parent_model(self) -> "jaxsim.high_level.model.Model": + """""" + return self._parent_model - def valid(self) -> bool: - return self.parent_model is not None + @functools.partial(oop.jax_tf.method_ro, jit=False) + def valid(self) -> jtp.Bool: + """""" + + return jnp.array(self.parent_model is not None, dtype=bool) + + @functools.partial(oop.jax_tf.method_ro, jit=False) + def index(self) -> jtp.Int: + """""" + + return jnp.array(self.joint_description.index, dtype=int) - def index(self) -> int: - return self.joint_description.index + @functools.partial(oop.jax_tf.method_ro) + def dofs(self) -> jtp.Int: + """""" - def dofs(self) -> int: - return 1 + return jnp.array(1, dtype=int) + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) def name(self) -> str: + """""" + return self.joint_description.name - def position(self, dof: int = 0) -> float: - return self.parent_model.joint_positions(joint_names=[self.name()])[dof] + @functools.partial(oop.jax_tf.method_ro) + def position(self, dof: int = None) -> jtp.Float: + """""" + + dof = dof if dof is not None else 0 + + return jnp.array( + self.parent_model.joint_positions(joint_names=(self.name(),))[dof], + dtype=float, + ) + + @functools.partial(oop.jax_tf.method_ro) + def velocity(self, dof: int = None) -> jtp.Float: + """""" - def velocity(self, dof: int = 0) -> float: - return self.parent_model.joint_velocities(joint_names=[self.name()])[dof] + dof = dof if dof is not None else 0 - def acceleration(self, dof: int = 0) -> float: - return self.parent_model.joint_accelerations(joint_names=[self.name()])[dof] + return jnp.array( + self.parent_model.joint_velocities(joint_names=(self.name(),))[dof], + dtype=float, + ) - def force(self, dof: int = 0) -> float: - return self.parent_model.joint_generalized_forces(joint_names=[self.name()])[ - dof - ] + @functools.partial(oop.jax_tf.method_ro) + def force_target(self, dof: int = None) -> jtp.Float: + """""" - def position_limit(self, dof: int = 0) -> Tuple[float, float]: - if dof != 0: + dof = dof if dof is not None else 0 + + return jnp.array( + self.parent_model.joint_generalized_forces_targets( + joint_names=(self.name(),) + )[dof], + dtype=float, + ) + + @functools.partial(oop.jax_tf.method_ro) + def position_limit(self, dof: int = None) -> tuple[jtp.Float, jtp.Float]: + """""" + + dof = dof if dof is not None else 0 + + if not_tracing(dof) and dof != 0: msg = "Only joints with 1 DoF are currently supported" raise ValueError(msg) - return self.joint_description.position_limit + low, high = self.joint_description.position_limit + + return jnp.array(low, dtype=float), jnp.array(high, dtype=float) # ================= # Multi-DoF methods # ================= + @functools.partial(oop.jax_tf.method_ro) def joint_position(self) -> jtp.Vector: - return self.parent_model.joint_positions(joint_names=[self.name()]) + """""" + + return self.parent_model.joint_positions(joint_names=(self.name(),)) + @functools.partial(oop.jax_tf.method_ro) def joint_velocity(self) -> jtp.Vector: - return self.parent_model.joint_velocities(joint_names=[self.name()]) + """""" + + return self.parent_model.joint_velocities(joint_names=(self.name(),)) - def joint_acceleration(self) -> jtp.Vector: - return self.parent_model.joint_accelerations(joint_names=[self.name()]) + @functools.partial(oop.jax_tf.method_ro) + def joint_force_target(self) -> jtp.Vector: + """""" - def joint_force(self) -> jtp.Vector: - return self.parent_model.joint_generalized_forces(joint_names=[self.name()]) + return self.parent_model.joint_generalized_forces_targets( + joint_names=(self.name(),) + ) diff --git a/src/jaxsim/high_level/link.py b/src/jaxsim/high_level/link.py index 203d40795..5a9010597 100644 --- a/src/jaxsim/high_level/link.py +++ b/src/jaxsim/high_level/link.py @@ -1,6 +1,8 @@ import dataclasses +import functools from typing import Any +import jax.lax import jax.numpy as jnp import jax_dataclasses import numpy as np @@ -10,95 +12,146 @@ import jaxsim.sixd as sixd import jaxsim.typing as jtp from jaxsim.physics.algos.jacobian import jacobian -from jaxsim.utils import JaxsimDataclass +from jaxsim.utils import Vmappable, oop from .common import VelRepr @jax_dataclasses.pytree_dataclass -class Link(JaxsimDataclass): +class Link(Vmappable): """ - High-level class to operate on a single link of a simulated model. + High-level class to operate in r/o on a single link of a simulated model. """ link_description: Static[jaxsim.parsers.descriptions.LinkDescription] - _parent_model: Any = dataclasses.field(default=None, repr=False, compare=False) + _parent_model: Any = dataclasses.field( + default=None, repr=False, compare=False, hash=False + ) @property def parent_model(self) -> "jaxsim.high_level.model.Model": + """""" + return self._parent_model - def valid(self) -> bool: - return self.parent_model is not None + @functools.partial(oop.jax_tf.method_ro, jit=False) + def valid(self) -> jtp.Bool: + """""" + + return jnp.array(self.parent_model is not None, dtype=bool) # ========== # Properties # ========== + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) def name(self) -> str: + """""" + return self.link_description.name - def index(self) -> int: - return self.link_description.index + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) + def index(self) -> jtp.Int: + """""" + + return jnp.array(self.link_description.index, dtype=int) # ======== # Dynamics # ======== + @functools.partial(oop.jax_tf.method_ro, jit=False) def mass(self) -> jtp.Float: - return self.link_description.mass + """""" + return jnp.array(self.link_description.mass, dtype=float) + + @functools.partial(oop.jax_tf.method_ro, jit=False) def spatial_inertia(self) -> jtp.Matrix: - return self.link_description.inertia + """""" + + return jnp.array(self.link_description.inertia, dtype=float) + + @functools.partial(oop.jax_tf.method_ro, vmap_in_axes=(0, None)) + def com_position(self, in_link_frame: bool = True) -> jtp.Vector: + """""" - def com_position(self, in_link_frame: bool = True) -> jtp.VectorJax: from jaxsim.math.inertia import Inertia _, L_p_CoM, _ = Inertia.to_params(M=self.spatial_inertia()) - if in_link_frame: + def com_in_link_frame(): return L_p_CoM.squeeze() - W_H_L = self.transform() - W_ph_CoM = W_H_L @ jnp.hstack([L_p_CoM.squeeze(), 1]) + def com_in_inertial_frame(): + W_H_L = self.transform() + W_p̃_CoM = W_H_L @ jnp.hstack([L_p_CoM.squeeze(), 1]) + + return W_p̃_CoM[0:3].squeeze() - return W_ph_CoM[0:3].squeeze() + return jax.lax.select( + pred=in_link_frame, + on_true=com_in_link_frame(), + on_false=com_in_inertial_frame(), + ) # ========== # Kinematics # ========== + @functools.partial(oop.jax_tf.method_ro) def position(self) -> jtp.Vector: + """""" + return self.transform()[0:3, 3] + @functools.partial(oop.jax_tf.method_ro, static_argnames=["dcm"]) def orientation(self, dcm: bool = False) -> jtp.Vector: + """""" + R = self.transform()[0:3, 0:3] to_wxyz = np.array([3, 0, 1, 2]) return R if dcm else sixd.so3.SO3.from_matrix(R).as_quaternion_xyzw()[to_wxyz] + @functools.partial(oop.jax_tf.method_ro) def transform(self) -> jtp.Matrix: + """""" + return self.parent_model.forward_kinematics()[self.index()] + @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"]) def velocity(self, vel_repr: VelRepr = None) -> jtp.Vector: + """""" + v_WL = ( self.jacobian(output_vel_repr=vel_repr) @ self.parent_model.generalized_velocity() ) + return v_WL + @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"]) def linear_velocity(self, vel_repr: VelRepr = None) -> jtp.Vector: + """""" + return self.velocity(vel_repr=vel_repr)[0:3] + @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"]) def angular_velocity(self, vel_repr: VelRepr = None) -> jtp.Vector: + """""" + return self.velocity(vel_repr=vel_repr)[3:6] + @functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"]) def jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix: + """""" + if output_vel_repr is None: output_vel_repr = self.parent_model.velocity_representation - # Return the doubly left-trivialized free-floating jacobian + # Compute the doubly left-trivialized free-floating jacobian L_J_WL_B = jacobian( model=self.parent_model.physics_model, body_index=self.index(), @@ -114,7 +167,14 @@ def jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix: B_X_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint() zero_6n = jnp.zeros(shape=(6, dofs)) - B_T_W = jnp.block([[B_X_W, zero_6n], [zero_6n.T, jnp.eye(dofs)]]) + + B_T_W = jnp.vstack( + [ + jnp.block([B_X_W, zero_6n]), + jnp.block([zero_6n.T, jnp.eye(dofs)]), + ] + ) + L_J_WL_target = L_J_WL_B @ B_T_W elif self.parent_model.velocity_representation is VelRepr.Mixed: @@ -124,7 +184,13 @@ def jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix: B_X_BW = sixd.se3.SE3.from_matrix(BW_H_B).inverse().adjoint() zero_6n = jnp.zeros(shape=(6, dofs)) - B_T_BW = jnp.block([[B_X_BW, zero_6n], [zero_6n.T, jnp.eye(dofs)]]) + + B_T_BW = jnp.vstack( + [ + jnp.block([B_X_BW, zero_6n]), + jnp.block([zero_6n.T, jnp.eye(dofs)]), + ] + ) L_J_WL_target = L_J_WL_B @ B_T_BW @@ -148,6 +214,7 @@ def jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix: else: raise ValueError(output_vel_repr) + @functools.partial(oop.jax_tf.method_ro) def external_force(self) -> jtp.Vector: """ Return the active external force acting on the link. @@ -160,91 +227,33 @@ def external_force(self) -> jtp.Vector: The active external 6D force acting on the link in the active representation. """ + # Get the external force stored in the inertial representation W_f_ext = self.parent_model.data.model_input.f_ext[self.index()] - return self.parent_model.inertial_to_active_representation( - array=W_f_ext, is_force=True - ) - - def add_external_force( - self, force: jtp.Array = None, torque: jtp.Array = None - ) -> None: - force = force if force is not None else jnp.zeros(3) - torque = torque if torque is not None else jnp.zeros(3) - - f_ext = jnp.hstack([force, torque]) - + # Express it in the active representation if self.parent_model.velocity_representation is VelRepr.Inertial: - W_f_ext = f_ext + f_ext = W_f_ext elif self.parent_model.velocity_representation is VelRepr.Body: - L_f_ext = f_ext W_H_L = self.transform() - L_X_W = sixd.se3.SE3.from_matrix(W_H_L).inverse().adjoint() + W_X_L = sixd.se3.SE3.from_matrix(W_H_L).adjoint() - W_f_ext = L_X_W.transpose() @ L_f_ext + f_ext = L_f_ext = W_X_L.transpose() @ W_f_ext elif self.parent_model.velocity_representation is VelRepr.Mixed: - LW_f_ext = f_ext - W_p_L = self.transform()[0:3, 3] W_H_LW = jnp.eye(4).at[0:3, 3].set(W_p_L) - LW_X_W = sixd.se3.SE3.from_matrix(W_H_LW).inverse().adjoint() - - W_f_ext = LW_X_W @ LW_f_ext - - else: - raise ValueError(self.parent_model.velocity_representation) - - W_f_ext_current = self.parent_model.data.model_input.f_ext[self.index(), :] - - self.parent_model.data.model_input.f_ext = ( - self.parent_model.data.model_input.f_ext.at[self.index(), :].set( - W_f_ext_current + W_f_ext - ) - ) - - def add_com_external_force( - self, force: jtp.Array = None, torque: jtp.Array = None - ) -> None: - force = force if force is not None else jnp.zeros(3) - torque = torque if torque is not None else jnp.zeros(3) - - f_ext = jnp.hstack([force, torque]) - - if self.parent_model.velocity_representation is VelRepr.Inertial: - W_f_ext = f_ext - - elif self.parent_model.velocity_representation is VelRepr.Body: - GL_f_ext = f_ext + W_X_LW = sixd.se3.SE3.from_matrix(W_H_LW).adjoint() - W_H_L = self.transform() - L_p_CoM = self.com_position(in_link_frame=True) - L_H_GL = jnp.eye(4).at[0:3, 3].set(L_p_CoM) - W_H_GL = W_H_L @ L_H_GL - GL_X_W = sixd.se3.SE3.from_matrix(W_H_GL).inverse().adjoint() - - W_f_ext = GL_X_W.transpose() @ GL_f_ext - - elif self.parent_model.velocity_representation is VelRepr.Mixed: - GW_f_ext = f_ext - - W_p_CoM = self.com_position(in_link_frame=False) - W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) - GW_X_W = sixd.se3.SE3.from_matrix(W_H_GW).inverse().adjoint() - - W_f_ext = GW_X_W.transpose() @ GW_f_ext + f_ext = LW_f_ext = W_X_LW.transpose() @ W_f_ext else: raise ValueError(self.parent_model.velocity_representation) - W_f_ext_current = self.parent_model.data.model_input.f_ext[self.index(), :] - - self.parent_model.data.model_input.f_ext = ( - self.parent_model.data.model_input.f_ext.at[self.index(), :].set( - W_f_ext_current + W_f_ext - ) - ) + return f_ext + @functools.partial(oop.jax_tf.method_ro) def in_contact(self) -> jtp.Bool: + """""" + return self.parent_model.in_contact()[self.index()] diff --git a/src/jaxsim/high_level/model.py b/src/jaxsim/high_level/model.py index 94eb877b6..db64e4294 100644 --- a/src/jaxsim/high_level/model.py +++ b/src/jaxsim/high_level/model.py @@ -1,4 +1,5 @@ import dataclasses +import functools import pathlib from typing import Any, Dict, List, Optional, Tuple, Union @@ -19,13 +20,9 @@ from jaxsim import high_level, logging, physics, sixd from jaxsim.physics.algos import soft_contacts from jaxsim.physics.algos.terrain import FlatTerrain, Terrain -from jaxsim.simulation import ode_data, ode_integration -from jaxsim.simulation.ode_integration import IntegratorType -from jaxsim.utils import JaxsimDataclass, Mutability +from jaxsim.utils import JaxsimDataclass, Mutability, Vmappable, oop from .common import VelRepr -from .joint import Joint -from .link import Link @jax_dataclasses.pytree_dataclass @@ -97,7 +94,7 @@ class StepData(JaxsimDataclass): @jax_dataclasses.pytree_dataclass -class Model(JaxsimDataclass): +class Model(Vmappable): """ High-level class to operate on a simulated model. """ @@ -110,14 +107,6 @@ class Model(JaxsimDataclass): velocity_representation: Static[VelRepr] = dataclasses.field(default=VelRepr.Mixed) - _links: Static[Dict[str, Link]] = dataclasses.field( - default_factory=list, repr=False - ) - - _joints: Static[Dict[str, Joint]] = dataclasses.field( - default_factory=list, repr=False - ) - data: ModelData = dataclasses.field(default=None, repr=False) # ======================== @@ -228,81 +217,44 @@ def build( model_name if model_name is not None else physics_model.description.name ) - # Sort all the joints by their index - sorted_links = { - l.name: high_level.link.Link(link_description=l) - for l in sorted( - physics_model.description.links_dict.values(), key=lambda l: l.index - ) - } - - # Sort all the joints by their index - sorted_joints = { - j.name: high_level.joint.Joint(joint_description=j) - for j in sorted( - physics_model.description.joints_dict.values(), - key=lambda j: j.index, - ) - } - # Build the high-level model model = Model( physics_model=physics_model, model_name=model_name, velocity_representation=vel_repr, - _links=sorted_links, - _joints=sorted_joints, ) # Zero the model data - with model.editable(validate=False) as model: + with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): model.zero() # Check model validity if not model.valid(): - raise RuntimeError + raise RuntimeError("The model is not valid.") # Return the high-level model return model - def __post_init__(self): - """Post-init logic. Use the static methods to build high-level models.""" - - original_mutability = self._mutability() - self._set_mutability(Mutability.MUTABLE_NO_VALIDATION) - - for l in self._links.values(): - l.mutable(validate=False)._parent_model = self - - for j in self._joints.values(): - j.mutable(validate=False)._parent_model = self - - self._links: Dict[str, high_level.link.Link] = { - k: v for k, v in sorted(self._links.items(), key=lambda kv: kv[1].index()) - } - - self._joints: Dict[str, high_level.joint.Joint] = { - k: v for k, v in sorted(self._joints.items(), key=lambda kv: kv[1].index()) - } - - self._set_mutability(original_mutability) - + @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False) def reduce( - self, considered_joints: List[str], keep_base_pose: bool = False + self, considered_joints: tuple[str, ...], keep_base_pose: bool = False ) -> None: """ Reduce the model by lumping together the links connected by removed joints. Args: - considered_joints: The list of joints to consider. + considered_joints: The sequence of joints to consider. keep_base_pose: A flag indicating whether to keep the base pose or not. """ + if self.vectorized: + raise RuntimeError("Cannot reduce a vectorized model.") + # Reduce the model description. # If considered_joints contains joints not existing in the model, the method # will raise an exception. reduced_model_description = self.physics_model.description.reduce( - considered_joints=considered_joints + considered_joints=list(considered_joints) ) # Create the physics model from the reduced model description @@ -324,36 +276,39 @@ def reduce( # Replace the current model with the reduced model. # Since the structure of the PyTree changes, we disable validation. - with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): - self.physics_model = reduced_model.physics_model - self.data = reduced_model.data - self._links = reduced_model._links - self._joints = reduced_model._joints + self.physics_model = reduced_model.physics_model + self.data = reduced_model.data if keep_base_pose: - with self.mutable_context(mutability=Mutability.MUTABLE): - self.reset_base_position(position=W_p_B) - self.reset_base_orientation(orientation=W_Q_B, dcm=False) + self.reset_base_position(position=W_p_B) + self.reset_base_orientation(orientation=W_Q_B, dcm=False) + @functools.partial(oop.jax_tf.method_rw, jit=False) def zero(self) -> None: + """""" + self.data = ModelData.zero(physics_model=self.physics_model) - self.data._set_mutability(self._mutability()) + @functools.partial(oop.jax_tf.method_rw, jit=False) def zero_input(self) -> None: + """""" + self.data.model_input = ModelData.zero( physics_model=self.physics_model ).model_input - self.data._set_mutability(self._mutability()) - + @functools.partial(oop.jax_tf.method_rw, jit=False) def zero_state(self) -> None: + """""" + model_data_zero = ModelData.zero(physics_model=self.physics_model) self.data.model_state = model_data_zero.model_state self.data.contact_state = model_data_zero.contact_state - self.data._set_mutability(self._mutability()) - + @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False) def set_velocity_representation(self, vel_repr: VelRepr) -> None: + """""" + if self.velocity_representation is vel_repr: return @@ -363,72 +318,140 @@ def set_velocity_representation(self, vel_repr: VelRepr) -> None: # Properties # ========== - def valid(self) -> bool: + @functools.partial(oop.jax_tf.method_ro, jit=False) + def valid(self) -> jtp.Bool: + """""" + valid = True valid = valid and all([l.valid() for l in self.links()]) valid = valid and all([j.valid() for j in self.joints()]) - return valid + return jnp.array(valid, dtype=bool) + + @functools.partial(oop.jax_tf.method_ro, jit=False) + def floating_base(self) -> jtp.Bool: + """""" - def floating_base(self) -> bool: - return self.physics_model.is_floating_base + return jnp.array(self.physics_model.is_floating_base, dtype=bool) - def dofs(self) -> int: - return self.physics_model.dofs() + @functools.partial(oop.jax_tf.method_ro, jit=False) + def dofs(self) -> jtp.Int: + """""" + return self.joint_positions().size + + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) def name(self) -> str: + """""" + return self.model_name - def nr_of_links(self) -> int: - return len(self._links) + @functools.partial(oop.jax_tf.method_ro, jit=False) + def nr_of_links(self) -> jtp.Int: + """""" + + return jnp.array(len(self.links()), dtype=int) + + @functools.partial(oop.jax_tf.method_ro, jit=False) + def nr_of_joints(self) -> jtp.Int: + """""" - def nr_of_joints(self) -> int: - return len(self._joints) + return jnp.array(len(self.joints()), dtype=int) + @functools.partial(oop.jax_tf.method_ro) def total_mass(self) -> jtp.Float: - return jnp.sum(jnp.array([l.mass() for l in self.links()])) + """""" + + return jnp.sum(jnp.array([l.mass() for l in self.links()]), dtype=float) + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) def get_link(self, link_name: str) -> high_level.link.Link: + """""" + if link_name not in self.link_names(): msg = f"Link '{link_name}' is not part of model '{self.name()}'" raise ValueError(msg) - return self.links(link_names=[link_name])[0] + return self.links(link_names=(link_name,))[0] + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) def get_joint(self, joint_name: str) -> high_level.joint.Joint: + """""" + if joint_name not in self.joint_names(): msg = f"Joint '{joint_name}' is not part of model '{self.name()}'" raise ValueError(msg) - return self.joints(joint_names=[joint_name])[0] + return self.joints(joint_names=(joint_name,))[0] + + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) + def link_names(self) -> tuple[str, ...]: + """""" + + return tuple(l.name() for l in self.links()) - def link_names(self) -> List[str]: - return list(self._links.keys()) + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) + def joint_names(self) -> tuple[str, ...]: + """""" + + return tuple(j.name() for j in self.joints()) + + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) + def links( + self, link_names: tuple[str, ...] = None + ) -> tuple[high_level.link.Link, ...]: + """""" + + all_links = { + l.name: high_level.link.Link( + link_description=l, _parent_model=self, batch_size=self.batch_size + ) + for l in sorted( + self.physics_model.description.links_dict.values(), + key=lambda l: l.index, + ) + } - def joint_names(self) -> List[str]: - return list(self._joints.keys()) + for l in all_links.values(): + l._set_mutability(self._mutability()) - def links(self, link_names: List[str] = None) -> List[high_level.link.Link]: if link_names is None: - return list(self._links.values()) + return tuple(all_links.values()) - return [self._links[name] for name in link_names] + return tuple(all_links[name] for name in link_names) + + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) + def joints( + self, joint_names: tuple[str, ...] = None + ) -> tuple[high_level.joint.Joint, ...]: + """""" + + all_joints = { + j.name: high_level.joint.Joint( + joint_description=j, _parent_model=self, batch_size=self.batch_size + ) + for j in sorted( + self.physics_model.description.joints_dict.values(), + key=lambda j: j.index, + ) + } + + for j in all_joints.values(): + j._set_mutability(self._mutability()) - def joints(self, joint_names: List[str] = None) -> List[high_level.joint.Joint]: if joint_names is None: - return list(self._joints.values()) + return tuple(all_joints.values()) - return [self._joints[name] for name in joint_names] + return tuple(all_joints[name] for name in joint_names) + @functools.partial(oop.jax_tf.method_ro, static_argnames=["link_names", "terrain"]) def in_contact( - self, - link_names: Optional[List[str]] = None, - terrain: Terrain = FlatTerrain(), + self, link_names: tuple[str, ...] = None, terrain: Terrain = FlatTerrain() ) -> jtp.Vector: """""" link_names = link_names if link_names is not None else self.link_names() - if set(link_names) - set(self._links.keys()) != set(): + if set(link_names) - set(self.link_names()) != set(): raise ValueError("One or more link names are not part of the model") from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel @@ -454,25 +477,26 @@ def in_contact( return links_in_contact - # ================== - # Vectorized methods - # ================== + # ================= + # Multi-DoF methods + # ================= - def joint_positions(self, joint_names: List[str] = None) -> jtp.Vector: - if self.dofs() == 0 and (joint_names is None or len(joint_names) == 0): - return jnp.array([]) + @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) + def joint_positions(self, joint_names: tuple[str, ...] = None) -> jtp.Vector: + """""" return self.data.model_state.joint_positions[ self._joint_indices(joint_names=joint_names) ] + @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) def joint_random_positions( - self, - joint_names: List[str] = None, - key: jax.random.PRNGKeyArray = jax.random.PRNGKey(seed=0), + self, joint_names: tuple[str, ...] = None, key: jax.Array = None ) -> jtp.Vector: - if self.dofs() == 0 and (joint_names is None or len(joint_names) == 0): - return jnp.array([]) + """""" + + if key is None: + key = jax.random.PRNGKey(seed=0) s_min, s_max = self.joint_limits(joint_names=joint_names) @@ -485,70 +509,99 @@ def joint_random_positions( return s_random - def joint_velocities(self, joint_names: List[str] = None) -> jtp.Vector: - if self.dofs() == 0 and (joint_names is None or len(joint_names) == 0): - return jnp.array([]) + @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) + def joint_velocities(self, joint_names: tuple[str, ...] = None) -> jtp.Vector: + """""" return self.data.model_state.joint_velocities[ self._joint_indices(joint_names=joint_names) ] + @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) def joint_generalized_forces_targets( - self, joint_names: List[str] = None + self, joint_names: tuple[str, ...] = None ) -> jtp.Vector: - if self.dofs() == 0 and (joint_names is None or len(joint_names) == 0): - return jnp.array([]) + """""" return self.data.model_input.tau[self._joint_indices(joint_names=joint_names)] + @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) def joint_limits( - self, joint_names: List[str] = None + self, joint_names: tuple[str, ...] = None ) -> Tuple[jtp.Vector, jtp.Vector]: - if self.dofs() == 0 and (joint_names is None or len(joint_names) == 0): - return jnp.array([]) + """""" + # Consider all joints if not specified otherwise joint_names = joint_names if joint_names is not None else self.joint_names() - s_min = jnp.array( - [min(self.get_joint(name).position_limit()) for name in joint_names] + # Create a (Dofs, 2) matrix containing the joint limits + limits = jnp.vstack( + jnp.array([j.position_limit() for j in self.joints(joint_names)]) ) - s_max = jnp.array( - [max(self.get_joint(name).position_limit()) for name in joint_names] - ) + # Get the limits, reordering them in case low > high + s_low = jnp.min(limits, axis=1) + s_high = jnp.max(limits, axis=1) - return s_min, s_max + return s_low, s_high # ========= # Base link # ========= + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) def base_frame(self) -> str: + """""" + return self.physics_model.description.root.name + @functools.partial(oop.jax_tf.method_ro) def base_position(self) -> jtp.Vector: + """""" + return self.data.model_state.base_position.squeeze() + @functools.partial(oop.jax_tf.method_ro, static_argnames=["dcm"]) def base_orientation(self, dcm: bool = False) -> jtp.Vector: + """""" + + # Normalize the quaternion before using it. + # Our integration logic has a Baumgarte stabilization term makes the quaternion + # norm converge to 1, but it does not enforce to be 1 at all the time instants. + base_unit_quaternion = ( + self.data.model_state.base_quaternion.squeeze() + / jnp.linalg.norm(self.data.model_state.base_quaternion) + ) + + # wxyz -> xyzw to_xyzw = np.array([1, 2, 3, 0]) return ( - self.data.model_state.base_quaternion + base_unit_quaternion if not dcm else sixd.so3.SO3.from_quaternion_xyzw( - self.data.model_state.base_quaternion[to_xyzw] + base_unit_quaternion[to_xyzw] ).as_matrix() ) + @functools.partial(oop.jax_tf.method_ro) def base_transform(self) -> jtp.MatrixJax: - return jnp.block( + """""" + + W_R_B = self.base_orientation(dcm=True) + W_p_B = jnp.vstack(self.base_position()) + + return jnp.vstack( [ - [self.base_orientation(dcm=True), jnp.vstack(self.base_position())], - [0, 0, 0, 1], + jnp.block([W_R_B, W_p_B]), + jnp.array([0, 0, 0, 1]), ] ) + @functools.partial(oop.jax_tf.method_ro) def base_velocity(self) -> jtp.Vector: + """""" + W_v_WB = jnp.hstack( [ self.data.model_state.base_linear_velocity, @@ -558,6 +611,7 @@ def base_velocity(self) -> jtp.Vector: return self.inertial_to_active_representation(array=W_v_WB) + @functools.partial(oop.jax_tf.method_ro) def external_forces(self) -> jtp.Matrix: """ Return the active external forces acting on the robot. @@ -581,16 +635,141 @@ def external_forces(self) -> jtp.Matrix: return jax.vmap(inertial_to_active, in_axes=0)(W_f_ext) - # ================== - # Dynamic properties - # ================== + # ======================= + # Single link r/w methods + # ======================= + @functools.partial( + oop.jax_tf.method_rw, jit=True, static_argnames=["link_name", "additive"] + ) + def apply_external_force_to_link( + self, + link_name: str, + force: jtp.Array = None, + torque: jtp.Array = None, + additive: bool = True, + ) -> None: + """""" + + # Get the target link with the correct mutability + link = self.get_link(link_name=link_name) + link._set_mutability(mutability=self._mutability()) + + # Initialize zero force components if not set + force = force if force is not None else jnp.zeros(3) + torque = torque if torque is not None else jnp.zeros(3) + + # Build the target 6D force in the active representation + f_ext = jnp.hstack([force, torque]) + + # Convert the 6D force to the inertial representation + if self.velocity_representation is VelRepr.Inertial: + W_f_ext = f_ext + + elif self.velocity_representation is VelRepr.Body: + L_f_ext = f_ext + W_H_L = link.transform() + L_X_W = sixd.se3.SE3.from_matrix(W_H_L).inverse().adjoint() + + W_f_ext = L_X_W.transpose() @ L_f_ext + + elif self.velocity_representation is VelRepr.Mixed: + LW_f_ext = f_ext + + W_p_L = link.transform()[0:3, 3] + W_H_LW = jnp.eye(4).at[0:3, 3].set(W_p_L) + LW_X_W = sixd.se3.SE3.from_matrix(W_H_LW).inverse().adjoint() + + W_f_ext = LW_X_W.transpose() @ LW_f_ext + + else: + raise ValueError(self.velocity_representation) + + # Obtain the new 6D force considering the 'additive' flag + W_f_ext_current = self.data.model_input.f_ext[link.index(), :] + new_force = W_f_ext_current + W_f_ext if additive else W_f_ext + + # Update the model data + self.data.model_input.f_ext = self.data.model_input.f_ext.at[ + link.index(), : + ].set(new_force) + + @functools.partial( + oop.jax_tf.method_rw, jit=True, static_argnames=["link_name", "additive"] + ) + def apply_external_force_to_link_com( + self, + link_name: str, + force: jtp.Array = None, + torque: jtp.Array = None, + additive: bool = True, + ) -> None: + """""" + + # Get the target link with the correct mutability + link = self.get_link(link_name=link_name) + link._set_mutability(mutability=self._mutability()) + + # Initialize zero force components if not set + force = force if force is not None else jnp.zeros(3) + torque = torque if torque is not None else jnp.zeros(3) + + # Build the target 6D force in the active representation + f_ext = jnp.hstack([force, torque]) + + # Convert the 6D force to the inertial representation + if self.velocity_representation is VelRepr.Inertial: + W_f_ext = f_ext + + elif self.velocity_representation is VelRepr.Body: + GL_f_ext = f_ext + + W_H_L = link.transform() + L_p_CoM = link.com_position(in_link_frame=True) + L_H_GL = jnp.eye(4).at[0:3, 3].set(L_p_CoM) + W_H_GL = W_H_L @ L_H_GL + GL_X_W = sixd.se3.SE3.from_matrix(W_H_GL).inverse().adjoint() + + W_f_ext = GL_X_W.transpose() @ GL_f_ext + + elif self.velocity_representation is VelRepr.Mixed: + GW_f_ext = f_ext + + W_p_CoM = link.com_position(in_link_frame=False) + W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) + GW_X_W = sixd.se3.SE3.from_matrix(W_H_GW).inverse().adjoint() + + W_f_ext = GW_X_W.transpose() @ GW_f_ext + + else: + raise ValueError(self.velocity_representation) + + # Obtain the new 6D force considering the 'additive' flag + W_f_ext_current = self.data.model_input.f_ext[link.index(), :] + new_force = W_f_ext_current + W_f_ext if additive else W_f_ext + + # Update the model data + self.data.model_input.f_ext = self.data.model_input.f_ext.at[ + link.index(), : + ].set(new_force) + + # ================================================ + # Generalized methods and free-floating quantities + # ================================================ + + @functools.partial(oop.jax_tf.method_ro) def generalized_position(self) -> Tuple[jtp.Matrix, jtp.Vector]: + """""" + return self.base_transform(), self.joint_positions() + @functools.partial(oop.jax_tf.method_ro) def generalized_velocity(self) -> jtp.Vector: + """""" + return jnp.hstack([self.base_velocity(), self.joint_velocities()]) + @functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"]) def generalized_free_floating_jacobian( self, output_vel_repr: VelRepr = None ) -> jtp.Matrix: @@ -639,7 +818,10 @@ def to_output(W_J_Wi): return J_free_floating + @functools.partial(oop.jax_tf.method_ro) def free_floating_mass_matrix(self) -> jtp.Matrix: + """""" + M_body = jaxsim.physics.algos.crba.crba( model=self.physics_model, q=self.data.model_state.joint_positions, @@ -652,7 +834,12 @@ def free_floating_mass_matrix(self) -> jtp.Matrix: zero_6n = jnp.zeros(shape=(6, self.dofs())) B_X_W = sixd.se3.SE3.from_matrix(self.base_transform()).inverse().adjoint() - invT = jnp.block([[B_X_W, zero_6n], [zero_6n.T, jnp.eye(self.dofs())]]) + invT = jnp.vstack( + [ + jnp.block([B_X_W, zero_6n]), + jnp.block([zero_6n.T, jnp.eye(self.dofs())]), + ] + ) return invT.T @ M_body @ invT @@ -661,49 +848,57 @@ def free_floating_mass_matrix(self) -> jtp.Matrix: W_H_BW = self.base_transform().at[0:3, 3].set(jnp.zeros(3)) BW_X_W = sixd.se3.SE3.from_matrix(W_H_BW).inverse().adjoint() - invT = jnp.block([[BW_X_W, zero_6n], [zero_6n.T, jnp.eye(self.dofs())]]) + invT = jnp.vstack( + [ + jnp.block([BW_X_W, zero_6n]), + jnp.block([zero_6n.T, jnp.eye(self.dofs())]), + ] + ) return invT.T @ M_body @ invT else: raise ValueError(self.velocity_representation) + @functools.partial(oop.jax_tf.method_ro) def free_floating_bias_forces(self) -> jtp.Vector: - with self.editable(validate=True) as model: - model.zero() + """""" - state = self.data.model_state.copy() - model.data.model_state.base_position = state.base_position - model.data.model_state.base_quaternion = state.base_quaternion - model.data.model_state.joint_positions = state.joint_positions - model.data.model_state.base_linear_velocity = state.base_linear_velocity - model.data.model_state.base_angular_velocity = state.base_angular_velocity - model.data.model_state.joint_velocities = state.joint_velocities + with self.editable(validate=True) as model: + model.zero_input() return jnp.hstack( model.inverse_dynamics( - base_acceleration=jnp.zeros(6), - joint_accelerations=jnp.zeros(model.dofs()), + base_acceleration=jnp.zeros(6), joint_accelerations=None ) ) + @functools.partial(oop.jax_tf.method_ro) def free_floating_gravity_forces(self) -> jtp.Vector: - with self.editable(validate=True) as model: - model.zero() + """""" - state = self.data.model_state.copy() - model.data.model_state.base_position = state.base_position - model.data.model_state.base_quaternion = state.base_quaternion - model.data.model_state.joint_positions = state.joint_positions + with self.editable(validate=True) as model: + model.zero_input() + model.data.model_state.joint_velocities = jnp.zeros_like( + model.data.model_state.joint_velocities + ) + model.data.model_state.base_linear_velocity = jnp.zeros_like( + model.data.model_state.base_linear_velocity + ) + model.data.model_state.base_angular_velocity = jnp.zeros_like( + model.data.model_state.base_angular_velocity + ) return jnp.hstack( model.inverse_dynamics( - base_acceleration=jnp.zeros(6), - joint_accelerations=jnp.zeros(model.dofs()), + base_acceleration=jnp.zeros(6), joint_accelerations=None ) ) + @functools.partial(oop.jax_tf.method_ro) def momentum(self) -> jtp.Vector: + """""" + with self.editable(validate=True) as m: m.set_velocity_representation(vel_repr=VelRepr.Body) @@ -718,19 +913,27 @@ def momentum(self) -> jtp.Vector: W_h = B_X_W.T @ B_h return self.inertial_to_active_representation(array=W_h, is_force=True) - # ============================== - # Quantities related to the CoM - # ============================== + # =========== + # CoM methods + # =========== + @functools.partial(oop.jax_tf.method_ro) def com_position(self) -> jtp.Vector: + """""" + m = self.total_mass() W_H_L = self.forward_kinematics() W_H_B = self.base_transform() - B_H_W = jnp.linalg.inv(W_H_B) + B_H_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().as_matrix() com_links = [ - (l.mass() * B_H_W @ W_H_L[l.index()] @ jnp.hstack([l.com(), 1])) + ( + l.mass() + * B_H_W + @ W_H_L[l.index()] + @ jnp.hstack([l.com_position(in_link_frame=True), 1]) + ) for l in self.links() ] @@ -742,7 +945,10 @@ def com_position(self) -> jtp.Vector: # Algorithms # ========== + @functools.partial(oop.jax_tf.method_ro) def forward_kinematics(self) -> jtp.Array: + """""" + W_H_i = jaxsim.physics.algos.forward_kinematics.forward_kinematics_model( model=self.physics_model, q=self.data.model_state.joint_positions, @@ -751,10 +957,11 @@ def forward_kinematics(self) -> jtp.Array: return W_H_i + @functools.partial(oop.jax_tf.method_ro) def inverse_dynamics( self, joint_accelerations: jtp.Vector = None, - base_acceleration: jtp.Vector = jnp.zeros(6), + base_acceleration: jtp.Vector = None, ) -> Tuple[jtp.Vector, jtp.Vector]: """ Compute inverse dynamics with the RNEA algorithm. @@ -776,8 +983,10 @@ def inverse_dynamics( else jnp.zeros_like(self.joint_positions()) ) - if joint_accelerations.size != self.dofs(): - raise ValueError(joint_accelerations.size) + # Build base acceleration if not provided + base_acceleration = ( + base_acceleration if base_acceleration is not None else jnp.zeros(6) + ) if base_acceleration.size != 6: raise ValueError(base_acceleration.size) @@ -839,23 +1048,29 @@ def to_inertial(C_vd_WB, W_H_C, C_v_WB, W_vl_WC): return f_B, tau + @functools.partial(oop.jax_tf.method_ro, static_argnames=["prefer_aba"]) def forward_dynamics( self, tau: jtp.Vector = None, prefer_aba: float = True ) -> Tuple[jtp.Vector, jtp.Vector]: + """""" + return ( self.forward_dynamics_aba(tau=tau) if prefer_aba else self.forward_dynamics_crb(tau=tau) ) + @functools.partial(oop.jax_tf.method_ro) def forward_dynamics_aba( self, tau: jtp.Vector = None ) -> Tuple[jtp.Vector, jtp.Vector]: + """""" + # Build joint torques if not provided tau = tau if tau is not None else jnp.zeros_like(self.joint_positions()) # Compute ABA - W_v̇_WB, sdd = jaxsim.physics.algos.aba.aba( + W_v̇_WB, s̈ = jaxsim.physics.algos.aba.aba( model=self.physics_model, xfb=self.data.model_state.xfb(), q=self.data.model_state.joint_positions, @@ -907,15 +1122,18 @@ def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC): ) # Adjust shape - sdd = jnp.atleast_1d(sdd.squeeze()) + s̈ = jnp.atleast_1d(s̈.squeeze()) - return C_v̇_WB, sdd + return C_v̇_WB, s̈ + @functools.partial(oop.jax_tf.method_ro) def forward_dynamics_crb( self, tau: jtp.Vector = None ) -> Tuple[jtp.Vector, jtp.Vector]: + """""" + # Build joint torques if not provided - τ = tau if tau is not None else jnp.zeros_like(self.joint_positions()) + τ = tau if tau is not None else jnp.zeros(shape=(self.dofs(),)) τ = jnp.atleast_1d(τ.squeeze()) τ = jnp.vstack(τ) if τ.size > 0 else jnp.empty(shape=(0, 1)) @@ -926,38 +1144,47 @@ def forward_dynamics_crb( f_ext = jnp.vstack(self.external_forces().flatten()) S = jnp.block([jnp.zeros(shape=(self.dofs(), 6)), jnp.eye(self.dofs())]).T - # Configure the slice for fixed/floating base robots - sl = np.s_[0:] if self.floating_base() else np.s_[6:] - # Compute the generalized acceleration by inverting the EoM - ν̇ = jnp.linalg.inv(M[sl, sl]) @ ((S @ τ)[sl] - h[sl] + J[:, sl].T @ f_ext) + ν̇ = jax.lax.select( + pred=self.floating_base(), + on_true=jnp.linalg.inv(M) @ ((S @ τ) - h + J.T @ f_ext), + on_false=jnp.vstack( + [ + jnp.zeros(shape=(6, 1)), + jnp.linalg.inv(M[6:, 6:]) + @ ((S @ τ)[6:] - h[6:] + J[:, 6:].T @ f_ext), + ] + ), + ).squeeze() # Extract the base acceleration in the active representation. # Note that this is an apparent acceleration (relevant in Mixed representation), # therefore it cannot be always expressed in different frames with just a # 6D transformation X. - a_WB = ν̇[0:6] if self.floating_base() else jnp.zeros(6) + v̇_WB = ν̇[0:6] # Extract the joint accelerations - sdd = ν̇[6:] if self.floating_base() else ν̇ + s̈ = jnp.atleast_1d(ν̇[6:]) - # Adjust shape and convert to lin-ang serialization - a_WB = a_WB.squeeze() - sdd = jnp.atleast_1d(sdd.squeeze()) - - return a_WB, sdd + return v̇_WB, s̈ # ====== # Energy # ====== + @functools.partial(oop.jax_tf.method_ro) def mechanical_energy(self) -> jtp.Float: + """""" + K = self.kinetic_energy() U = self.potential_energy() return K + U + @functools.partial(oop.jax_tf.method_ro) def kinetic_energy(self) -> jtp.Float: + """""" + with self.editable(validate=True) as m: m.set_velocity_representation(vel_repr=VelRepr.Body) @@ -966,7 +1193,10 @@ def kinetic_energy(self) -> jtp.Float: return 0.5 * nu.T @ M @ nu + @functools.partial(oop.jax_tf.method_ro) def potential_energy(self) -> jtp.Float: + """""" + m = self.total_mass() W_p_CoM = jnp.hstack([self.com_position(), 1]) gravity = self.physics_model.gravity[3:6].squeeze() @@ -977,9 +1207,12 @@ def potential_energy(self) -> jtp.Float: # Set targets # =========== + @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def set_joint_generalized_force_targets( - self, forces: jtp.Vector, joint_names: List[str] = None + self, forces: jtp.Vector, joint_names: tuple[str, ...] = None ) -> None: + """""" + if joint_names is None: joint_names = self.joint_names() @@ -994,9 +1227,12 @@ def set_joint_generalized_force_targets( # Reset data # ========== + @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def reset_joint_positions( - self, positions: jtp.Vector, joint_names: List[str] = None + self, positions: jtp.Vector, joint_names: tuple[str, ...] = None ) -> None: + """""" + if joint_names is None: joint_names = self.joint_names() @@ -1017,9 +1253,12 @@ def reset_joint_positions( ) ) + @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def reset_joint_velocities( - self, velocities: jtp.Vector, joint_names: List[str] = None + self, velocities: jtp.Vector, joint_names: tuple[str, ...] = None ) -> None: + """""" + if joint_names is None: joint_names = self.joint_names() @@ -1040,10 +1279,16 @@ def reset_joint_velocities( ) ) + @functools.partial(oop.jax_tf.method_rw) def reset_base_position(self, position: jtp.Vector) -> None: + """""" + self.data.model_state.base_position = jnp.array(position, dtype=float) + @functools.partial(oop.jax_tf.method_rw, static_argnames=["dcm"]) def reset_base_orientation(self, orientation: jtp.Array, dcm: bool = False) -> None: + """""" + if dcm: to_wxyz = np.array([3, 0, 1, 2]) orientation_xyzw = sixd.so3.SO3.from_matrix( @@ -1051,16 +1296,23 @@ def reset_base_orientation(self, orientation: jtp.Array, dcm: bool = False) -> N ).as_quaternion_xyzw() orientation = orientation_xyzw[to_wxyz] - self.data.model_state.base_quaternion = jnp.array(orientation, dtype=float) + unit_quaternion = orientation / jnp.linalg.norm(orientation) + self.data.model_state.base_quaternion = jnp.array(unit_quaternion, dtype=float) + @functools.partial(oop.jax_tf.method_rw) def reset_base_transform(self, transform: jtp.Matrix) -> None: + """""" + if transform.shape != (4, 4): raise ValueError(transform.shape) self.reset_base_position(position=transform[0:3, 3]) self.reset_base_orientation(orientation=transform[0:3, 0:3], dcm=True) + @functools.partial(oop.jax_tf.method_rw) def reset_base_velocity(self, base_velocity: jtp.VectorJax) -> None: + """""" + if not self.physics_model.is_floating_base: msg = "Changing the base velocity of a fixed-based model is not allowed" raise RuntimeError(msg) @@ -1107,16 +1359,31 @@ def reset_base_velocity(self, base_velocity: jtp.VectorJax) -> None: # Integration # =========== + @functools.partial( + oop.jax_tf.method_rw, + static_argnames=["sub_steps", "integrator_type", "terrain"], + vmap_in_axes=(0, 0, 0, None, None, None, 0, None), + ) def integrate( self, t0: jtp.Float, tf: jtp.Float, sub_steps: int = 1, - integrator_type: IntegratorType = IntegratorType.EulerForward, + integrator_type: Optional[ + "jaxsim.simulation.ode_integration.IntegratorType" + ] = None, terrain: soft_contacts.Terrain = soft_contacts.FlatTerrain(), contact_parameters: soft_contacts.SoftContactsParams = soft_contacts.SoftContactsParams(), clear_inputs: bool = False, ) -> StepData: + """""" + + from jaxsim.simulation import ode_data, ode_integration + from jaxsim.simulation.ode_integration import IntegratorType + + if integrator_type is None: + integrator_type = IntegratorType.EulerForward + x0 = ode_integration.ode.ode_data.ODEState( physics_model=self.data.model_state, soft_contacts=self.data.contact_state, @@ -1189,7 +1456,6 @@ def integrate( contact_state=tf_contact_state, model_input=model_input, ) - self._set_mutability(self._mutability()) return StepData( t0=t0, @@ -1217,9 +1483,12 @@ def integrate( # Private methods # =============== + @functools.partial(oop.jax_tf.method_ro, static_argnames=["is_force"]) def inertial_to_active_representation( self, array: jtp.Array, is_force: bool = False ) -> jtp.Array: + """""" + W_array = array.squeeze() if W_array.size != 6: @@ -1257,9 +1526,12 @@ def inertial_to_active_representation( else: raise ValueError(self.velocity_representation) + @functools.partial(oop.jax_tf.method_ro, static_argnames=["is_force"]) def active_to_inertial_representation( self, array: jtp.Array, is_force: bool = False ) -> jtp.Array: + """""" + array = array.squeeze() if array.size != 6: @@ -1300,11 +1572,13 @@ def active_to_inertial_representation( else: raise ValueError(self.velocity_representation) - def _joint_indices(self, joint_names: List[str] = None) -> jtp.Vector: + def _joint_indices(self, joint_names: tuple[str, ...] = None) -> jtp.Vector: + """""" + if joint_names is None: joint_names = self.joint_names() - if set(joint_names) - set(self._joints.keys()) != set(): + if set(joint_names) - set(self.joint_names()) != set(): raise ValueError("One or more joint names are not part of the model") # Note: joints share the same index as their child link, therefore the first @@ -1314,4 +1588,4 @@ def _joint_indices(self, joint_names: List[str] = None) -> jtp.Vector: j.joint_description.index - 1 for j in self.joints(joint_names=joint_names) ] - return np.array(joint_indices) + return np.array(joint_indices, dtype=int) diff --git a/src/jaxsim/math/adjoint.py b/src/jaxsim/math/adjoint.py index cc5748195..baa0bf46a 100644 --- a/src/jaxsim/math/adjoint.py +++ b/src/jaxsim/math/adjoint.py @@ -38,17 +38,17 @@ def from_rotation_and_translation( A_o_B = translation.squeeze() if not inverse: - X = A_X_B = jnp.block( + X = A_X_B = jnp.vstack( [ - [A_R_B, Skew.wedge(A_o_B) @ A_R_B], - [jnp.zeros(shape=(3, 3)), A_R_B], + jnp.block([A_R_B, Skew.wedge(A_o_B) @ A_R_B]), + jnp.block([jnp.zeros(shape=(3, 3)), A_R_B]), ] ) else: - X = B_X_A = jnp.block( + X = B_X_A = jnp.vstack( [ - [A_R_B.T, -A_R_B.T @ Skew.wedge(A_o_B)], - [jnp.zeros(shape=(3, 3)), A_R_B.T], + jnp.block([A_R_B.T, -A_R_B.T @ Skew.wedge(A_o_B)]), + jnp.block([jnp.zeros(shape=(3, 3)), A_R_B.T]), ] ) @@ -62,10 +62,10 @@ def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix: R = X[0:3, 0:3] o_x_R = X[0:3, 3:6] - H = jnp.block( + H = jnp.vstack( [ - [R, Skew.vee(matrix=o_x_R @ R.T)], - [0, 0, 0, 1], + jnp.block([R, Skew.vee(matrix=o_x_R @ R.T)]), + jnp.array([0, 0, 0, 1]), ] ) diff --git a/src/jaxsim/math/cross.py b/src/jaxsim/math/cross.py index bb8dc76a9..ff4195300 100644 --- a/src/jaxsim/math/cross.py +++ b/src/jaxsim/math/cross.py @@ -10,10 +10,10 @@ class Cross: def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix: v, ω = jnp.split(velocity_sixd.squeeze(), 2) - v_cross = jnp.block( + v_cross = jnp.vstack( [ - [Skew.wedge(vector=ω), Skew.wedge(vector=v)], - [jnp.zeros(shape=(3, 3)), Skew.wedge(vector=ω)], + jnp.block([Skew.wedge(vector=ω), Skew.wedge(vector=v)]), + jnp.block([jnp.zeros(shape=(3, 3)), Skew.wedge(vector=ω)]), ] ) diff --git a/src/jaxsim/math/inertia.py b/src/jaxsim/math/inertia.py index ceb9d690c..8505b4b3c 100644 --- a/src/jaxsim/math/inertia.py +++ b/src/jaxsim/math/inertia.py @@ -15,10 +15,10 @@ def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix: c = Skew.wedge(vector=com) - M = jnp.block( + M = jnp.vstack( [ - [mass * jnp.eye(3), mass * c.T], - [mass * c, I + mass * c @ c.T], + jnp.block([mass * jnp.eye(3), mass * c.T]), + jnp.block([mass * c, I + mass * c @ c.T]), ] ) diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index 69163059f..56d365463 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -33,7 +33,7 @@ def derivative( omega_in_body_fixed: bool = False, K: float = 0.1, ) -> jtp.Vector: - w = omega.squeeze() + ω = omega.squeeze() quaternion = quaternion.squeeze() def Q_body(q: jtp.Vector) -> jtp.Matrix: @@ -67,10 +67,20 @@ def Q_inertial(q: jtp.Vector) -> jtp.Matrix: operand=quaternion, ) + norm_ω = jax.lax.cond( + pred=ω.dot(ω) < (1e-6) ** 2, + true_fun=lambda _: 1e-6, + false_fun=lambda _: jnp.linalg.norm(ω), + operand=None, + ) + qd = 0.5 * ( Q @ jnp.hstack( - [K * jnp.linalg.norm(w) * (1 - jnp.linalg.norm(quaternion)), w] + [ + K * norm_ω * (1 - jnp.linalg.norm(quaternion)), + ω, + ] ) ) diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index bd0ed95d6..e1622bbdc 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -2,9 +2,12 @@ import enum from typing import Optional, Tuple, Union +import jax_dataclasses import numpy as np import numpy.typing as npt +from jaxsim.utils import JaxsimDataclass, Mutability + from .link import LinkDescription @@ -47,12 +50,16 @@ def __hash__(self) -> int: return hash(self.__repr__()) -@dataclasses.dataclass -class JointDescription: - name: str +@jax_dataclasses.pytree_dataclass +class JointDescription(JaxsimDataclass): + """ + In-memory description of a robot link. + """ + + name: jax_dataclasses.Static[str] axis: npt.NDArray pose: npt.NDArray - jtype: Union[JointType, JointDescriptor] + jtype: jax_dataclasses.Static[Union[JointType, JointDescriptor]] child: LinkDescription = dataclasses.dataclass(repr=False) parent: LinkDescription = dataclasses.dataclass(repr=False) @@ -69,8 +76,11 @@ class JointDescription: def __post_init__(self): if self.axis is not None: - norm_of_axis = np.linalg.norm(self.axis) - self.axis = self.axis / norm_of_axis + with self.mutable_context( + mutability=Mutability.MUTABLE, restore_after_exception=False + ): + norm_of_axis = np.linalg.norm(self.axis) + self.axis = self.axis / norm_of_axis def __hash__(self) -> int: return hash(self.__repr__()) diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index 93be26922..57472f1bf 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -48,8 +48,9 @@ def lump_with( I_removed_in_lumped_frame = r_X_l.transpose() @ I_removed @ r_X_l # Create the new combined link - lumped_link = copy.deepcopy(self) - lumped_link.mass = self.mass + link.mass - lumped_link.inertia = self.inertia + I_removed_in_lumped_frame + lumped_link = self.replace( + mass=self.mass + link.mass, + inertia=self.inertia + I_removed_in_lumped_frame, + ) return lumped_link diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index fce0b4af2..1cae72f2b 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -17,6 +17,7 @@ import numpy.typing as npt from jaxsim import logging +from jaxsim.utils import Mutability from . import descriptions @@ -82,7 +83,8 @@ def __post_init__(self): # Number joints so that their index matches their child link index links_dict = {l.name: l for l in iter(self)} for joint in self.joints: - joint.index = links_dict[joint.child.name].index + with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + joint.index = links_dict[joint.child.name].index # Check that joint indices are unique assert len([j.index for j in self.joints]) == len( @@ -299,12 +301,13 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": for joint in joints_with_removed_parent_link: # Update the pose. Note that after the lumping process, the dict entry # links_dict[joint.parent.name] contains the final lumped link - joint.pose = full_graph.relative_transform( - relative_to=links_dict[joint.parent.name].name, name=joint.name - ) - - # Update the parent link - joint.parent = links_dict[joint.parent.name] + with joint.mutable_context(mutability=Mutability.MUTABLE): + joint.pose = full_graph.relative_transform( + relative_to=links_dict[joint.parent.name].name, name=joint.name + ) + with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + # Update the parent link + joint.parent = links_dict[joint.parent.name] # =================================================================== # 3. Create the reduced graph considering the removed links as frames diff --git a/src/jaxsim/physics/algos/soft_contacts.py b/src/jaxsim/physics/algos/soft_contacts.py index 05815aa7a..07c81475d 100644 --- a/src/jaxsim/physics/algos/soft_contacts.py +++ b/src/jaxsim/physics/algos/soft_contacts.py @@ -229,6 +229,10 @@ def contact_model( m = tangential_deformation.squeeze() ṁ = jnp.zeros_like(m) + # Note: all the small hardcoded tolerances in this method have been introduced + # to allow jax differentiating through this algorithm. They should not affect + # the accuracy of the simulation, although they might make it less readable. + # ======================== # Normal force computation # ======================== @@ -249,7 +253,11 @@ def contact_model( # Non-linear spring-damper model. # This is the force magnitude along the direction normal to the terrain. - force_normal_mag = jnp.sqrt(δ) * (K * δ + D * δ̇) + force_normal_mag = jax.lax.select( + pred=δ >= 1e-9, + on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇), + on_false=jnp.array(0.0), + ) # Prevent negative normal forces that might occur when δ̇ is largely negative force_normal_mag = jnp.maximum(0.0, force_normal_mag) @@ -263,10 +271,10 @@ def contact_model( # Compute the adjoint C[W]->W for transforming 6D forces from mixed to inertial. # Note: this is equal to the 6D velocities transform: CW_X_W.transpose(). - W_Xf_CW = jnp.block( + W_Xf_CW = jnp.vstack( [ - [jnp.eye(3), jnp.zeros(shape=(3, 3))], - [Skew.wedge(W_p_C), jnp.eye(3)], + jnp.block([jnp.eye(3), jnp.zeros(shape=(3, 3))]), + jnp.block([Skew.wedge(W_p_C), jnp.eye(3)]), ] ) @@ -304,7 +312,7 @@ def below_terrain(): v_tangential = W_ṗ_C - v_normal # Compute the tangential force. If inside the friction cone, the contact - f_tangential = -jnp.sqrt(δ) * (K * m + D * v_tangential) + f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential) def sticking_contact(): # Sum the normal and tangential forces, and create the 6D force @@ -319,9 +327,17 @@ def sticking_contact(): return CW_f, ṁ def slipping_contact(): + # Clip the tangential force if too small, allowing jax to + # differentiate through the norm computation + f_tangential_no_nan = jax.lax.select( + pred=f_tangential.dot(f_tangential) >= 1e-9**2, + on_true=f_tangential, + on_false=jnp.array([1e-12, 0, 0]), + ) + # Project the force to the friction cone boundary f_tangential_projected = (μ * force_normal_mag) * ( - f_tangential / jnp.linalg.norm(f_tangential) + f_tangential / jnp.linalg.norm(f_tangential_no_nan) ) # Sum the normal and tangential forces, and create the 6D force @@ -331,18 +347,18 @@ def slipping_contact(): # Correct the material deformation derivative for slipping contacts. # Basically we compute ṁ such that we get `f_tangential` on the cone # given the current (m, δ). - ε = 1e-6 - α = -K * jnp.sqrt(δ) + ε = 1e-9 δε = jnp.maximum(δ, ε) - βε = -D * jnp.sqrt(δε) - ṁ = (f_tangential_projected - α * m) / βε + α = -K * jnp.sqrt(δε) + β = -D * jnp.sqrt(δε) + ṁ = (f_tangential_projected - α * m) / β # Return the 6D force in the contact frame and # the deformation derivative return CW_f, ṁ CW_f, ṁ = jax.lax.cond( - pred=jnp.linalg.norm(f_tangential) > μ * force_normal_mag, + pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2, true_fun=lambda _: slipping_contact(), false_fun=lambda _: sticking_contact(), operand=None, diff --git a/src/jaxsim/physics/algos/terrain.py b/src/jaxsim/physics/algos/terrain.py index d3316f24c..fc49576fd 100644 --- a/src/jaxsim/physics/algos/terrain.py +++ b/src/jaxsim/physics/algos/terrain.py @@ -47,4 +47,4 @@ def build(plane_normal: jtp.Vector) -> "PlaneTerrain": def height(self, x: float, y: float) -> float: a, b, c = self.plane_normal - return -(a * x + b * x) / c + return -(a * x + b * y) / c diff --git a/src/jaxsim/physics/model/physics_model.py b/src/jaxsim/physics/model/physics_model.py index c6dcbddc1..7eb4dbcc3 100644 --- a/src/jaxsim/physics/model/physics_model.py +++ b/src/jaxsim/physics/model/physics_model.py @@ -100,20 +100,22 @@ def build_from( # Dicts from the joint index to the static and viscous friction. # Note: the joint index is equal to its child link index. joint_friction_static = { - joint.index: joint.friction_static for joint in model_description.joints + joint.index: jnp.array(joint.friction_static, dtype=float) + for joint in model_description.joints } joint_friction_viscous = { - joint.index: joint.friction_viscous for joint in model_description.joints + joint.index: jnp.array(joint.friction_viscous, dtype=float) + for joint in model_description.joints } # Dicts from the joint index to the spring and damper joint limits parameters. # Note: the joint index is equal to its child link index. joint_limit_spring = { - joint.index: joint.position_limit_spring + joint.index: jnp.array(joint.position_limit_spring, dtype=float) for joint in model_description.joints } joint_limit_damper = { - joint.index: joint.position_limit_damper + joint.index: jnp.array(joint.position_limit_damper, dtype=float) for joint in model_description.joints } diff --git a/src/jaxsim/simulation/__init__.py b/src/jaxsim/simulation/__init__.py index af790bf21..a56ae20fa 100644 --- a/src/jaxsim/simulation/__init__.py +++ b/src/jaxsim/simulation/__init__.py @@ -1 +1,4 @@ -from . import integrators, ode, ode_data +from . import integrators, ode, ode_data, simulator +from .ode_data import ODEInput, ODEState +from .ode_integration import IntegratorType +from .simulator import JaxSim, SimulatorData diff --git a/src/jaxsim/simulation/simulator.py b/src/jaxsim/simulation/simulator.py index 034b96d23..fa96c3650 100644 --- a/src/jaxsim/simulation/simulator.py +++ b/src/jaxsim/simulation/simulator.py @@ -1,7 +1,12 @@ import dataclasses import functools import pathlib -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union + +try: + from typing import Self +except ImportError: + from typing_extensions import Self import jax import jax.numpy as jnp @@ -12,7 +17,6 @@ import jaxsim.high_level import jaxsim.parsers.descriptions as descriptions import jaxsim.physics -import jaxsim.simulation.simulator_callbacks as scb import jaxsim.typing as jtp from jaxsim import logging from jaxsim.high_level.common import VelRepr @@ -20,12 +24,14 @@ from jaxsim.physics.algos.soft_contacts import SoftContactsParams from jaxsim.physics.algos.terrain import FlatTerrain, Terrain from jaxsim.physics.model.physics_model import PhysicsModel -from jaxsim.simulation import ode_integration -from jaxsim.utils import JaxsimDataclass +from jaxsim.utils import Mutability, Vmappable, oop + +from . import simulator_callbacks as scb +from .ode_integration import IntegratorType @jax_dataclasses.pytree_dataclass -class SimulatorData(JaxsimDataclass): +class SimulatorData(Vmappable): """ Data used by the simulator. @@ -53,11 +59,11 @@ class SimulatorData(JaxsimDataclass): @jax_dataclasses.pytree_dataclass -class JaxSim(JaxsimDataclass): +class JaxSim(Vmappable): """The JaxSim simulator.""" # Step size stored in ns in order to prevent floats approximation - step_size_ns: jtp.Int = dataclasses.field( + step_size_ns: Static[jtp.Int] = dataclasses.field( default_factory=lambda: jnp.array(1_000_000, dtype=jnp.uint64) ) @@ -71,8 +77,8 @@ class JaxSim(JaxsimDataclass): ) # Integrator type - integrator_type: Static[ode_integration.IntegratorType] = dataclasses.field( - default=ode_integration.IntegratorType.EulerForward + integrator_type: Static[IntegratorType] = dataclasses.field( + default=IntegratorType.EulerForward ) # Simulator data @@ -83,7 +89,7 @@ def build( step_size: jtp.Float, steps_per_run: jtp.Int = 1, velocity_representation: VelRepr = VelRepr.Inertial, - integrator_type: ode_integration.IntegratorType = ode_integration.IntegratorType.EulerSemiImplicit, + integrator_type: IntegratorType = IntegratorType.EulerSemiImplicit, simulator_data: SimulatorData = None, ) -> "JaxSim": """ @@ -108,6 +114,9 @@ def build( data=simulator_data if simulator_data is not None else SimulatorData(), ) + @functools.partial( + oop.jax_tf.method_rw, static_argnames=["remove_models"], validate=False + ) def reset(self, remove_models: bool = True) -> None: """ Reset the simulator. @@ -124,6 +133,7 @@ def reset(self, remove_models: bool = True) -> None: else: _ = [m.zero() for m in self.models()] + @functools.partial(oop.jax_tf.method_rw, jit=False) def set_step_size(self, step_size: float) -> None: """ Set the integration step size. @@ -134,6 +144,18 @@ def set_step_size(self, step_size: float) -> None: self.step_size_ns = jnp.array(step_size * 1e9, dtype=jnp.uint64) + @functools.partial(oop.jax_tf.method_ro, jit=False) + def step_size(self) -> jtp.Float: + """ + Get the integration step size. + + Returns: + The integration step size in seconds. + """ + + return jnp.array(self.step_size_ns / 1e9, dtype=float) + + @functools.partial(oop.jax_tf.method_ro) def dt(self) -> jtp.Float: """ Return the integration step size in seconds. @@ -142,8 +164,9 @@ def dt(self) -> jtp.Float: The integration step size in seconds. """ - return (self.step_size_ns * self.steps_per_run) / 1e9 + return jnp.array((self.step_size_ns * self.steps_per_run) / 1e9, dtype=float) + @functools.partial(oop.jax_tf.method_ro) def time(self) -> jtp.Float: """ Return the current simulation time in seconds. @@ -152,8 +175,9 @@ def time(self) -> jtp.Float: The current simulation time in seconds. """ - return self.data.time_ns / 1e9 + return jnp.array(self.data.time_ns / 1e9, dtype=float) + @functools.partial(oop.jax_tf.method_ro) def gravity(self) -> jtp.Vector: """ Return the 3D gravity vector. @@ -162,9 +186,10 @@ def gravity(self) -> jtp.Vector: The 3D gravity vector. """ - return self.data.gravity + return jnp.array(self.data.gravity, dtype=float) - def model_names(self) -> List[str]: + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) + def model_names(self) -> tuple[str, ...]: """ Return the list of model names. @@ -172,8 +197,11 @@ def model_names(self) -> List[str]: The list of model names. """ - return list(self.data.models.keys()) + return tuple(self.data.models.keys()) + @functools.partial( + oop.jax_tf.method_ro, static_argnames=["model_name"], jit=False, vmap=False + ) def get_model(self, model_name: str) -> Model: """ Return the model with the given name. @@ -190,7 +218,8 @@ def get_model(self, model_name: str) -> Model: return self.data.models[model_name] - def models(self, model_names: List[str] = None) -> List[Model]: + @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) + def models(self, model_names: tuple[str, ...] = None) -> tuple[Model, ...]: """ Return the simulated models. @@ -203,8 +232,9 @@ def models(self, model_names: List[str] = None) -> List[Model]: """ model_names = model_names if model_names is not None else self.model_names() - return [self.data.models[name] for name in model_names] + return tuple(self.data.models[name] for name in model_names) + @functools.partial(oop.jax_tf.method_rw) def set_gravity(self, gravity: jtp.Vector) -> None: """ Set the gravity vector to all the simulated models. @@ -213,7 +243,7 @@ def set_gravity(self, gravity: jtp.Vector) -> None: gravity: The 3D gravity vector. """ - gravity = jnp.array(gravity) + gravity = jnp.array(gravity, dtype=float) if gravity.size != 3: raise ValueError(gravity) @@ -223,13 +253,12 @@ def set_gravity(self, gravity: jtp.Vector) -> None: for model_name, model in self.data.models.items(): model.physics_model.set_gravity(gravity=gravity) - self._set_mutability(self._mutability()) - + @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False) def insert_model_from_description( self, model_description: Union[pathlib.Path, str, rod.Model], model_name: Optional[str] = None, - considered_joints: Optional[List[str]] = None, + considered_joints: List[str] = None, ) -> Model: """ Insert a model from a model description. @@ -245,6 +274,9 @@ def insert_model_from_description( The newly inserted model. """ + if self.vectorized: + raise RuntimeError("Cannot insert a model in a vectorized simulation") + # Build the model from the given model description model = jaxsim.high_level.model.Model.build_from_model_description( model_description=model_description, @@ -261,13 +293,10 @@ def insert_model_from_description( # Insert the model self.data.models[model.name()] = model - # Propagate the current mutability property to make sure that also the - # newly inserted model matches the mutability of the simulator - self._set_mutability(self._mutability()) - # Return the newly inserted model return self.data.models[model.name()] + @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False) def insert_model_from_sdf( self, sdf: Union[pathlib.Path, str], @@ -289,6 +318,7 @@ def insert_model_from_sdf( considered_joints=considered_joints, ) + @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False) def insert_model( self, model_description: descriptions.ModelDescription, model_name: str = None ) -> Model: @@ -303,6 +333,9 @@ def insert_model( The newly inserted model. """ + if self.vectorized: + raise RuntimeError("Cannot insert a model in a vectorized simulation") + model_name = model_name if model_name is not None else model_description.name if model_name in self.model_names(): @@ -323,11 +356,16 @@ def insert_model( # Insert the model into the simulators self.data.models[model.name()] = model - self._set_mutability(self._mutability()) # Return the newly inserted model return self.data.models[model.name()] + @functools.partial( + oop.jax_tf.method_rw, + jit=False, + validate=False, + static_argnames=["model_name"], + ) def remove_model(self, model_name: str) -> None: """ Remove a model from the simulator. @@ -341,8 +379,8 @@ def remove_model(self, model_name: str) -> None: raise ValueError(msg) _ = self.data.models.pop(model_name) - self._set_mutability(self._mutability()) + @functools.partial(oop.jax_tf.method_rw, vmap_in_axes=(0, None)) def step(self, clear_inputs: bool = False) -> Dict[str, StepData]: """ Advance the simulation by one step. @@ -384,16 +422,22 @@ def step(self, clear_inputs: bool = False) -> Dict[str, StepData]: # Store the final time self.data.time_ns += dt_ns - self._set_mutability(self._mutability()) return step_data - @functools.partial(jax.jit, static_argnames=["horizon_steps"]) + @functools.partial( + oop.jax_tf.method_ro, + static_argnames=["horizon_steps"], + vmap_in_axes=(0, None, 0, None), + ) def step_over_horizon( self, horizon_steps: jtp.Int, callback_handler: Union["scb.SimulatorCallback", "scb.CallbackHandler"] = None, clear_inputs: jtp.Bool = False, - ) -> Union["JaxSim", Tuple["JaxSim", Tuple["scb.SimulatorCallback", jtp.PyTree]]]: + ) -> Union[ + "JaxSim", + tuple["JaxSim", tuple["scb.SimulatorCallback", tuple[jtp.PyTree, jtp.PyTree]]], + ]: """ Advance the simulation by a given number of steps. @@ -404,8 +448,9 @@ def step_over_horizon( Returns: The updated simulator if no callback handler is provided, otherwise a tuple - containing the updated simulator and a tuple with the updated callback object - and the optional output it produced. + containing the updated simulator and a tuple containing callback data. + The optional callback data is a tuple containing the updated callback object, + the produced pre-step output, and the produced post-step output. """ # Process a mutable copy of the simulator @@ -434,18 +479,21 @@ def step_over_horizon( sim = configure_cb(sim) if configure_cb is not None else sim # Initialize the carry - Carry = Tuple[JaxSim, scb.CallbackHandler] + Carry = tuple[JaxSim, scb.CallbackHandler] carry_init: Carry = (sim, callback_handler) - def body_fun(carry: Carry, xs: None) -> Tuple[Carry, jtp.PyTree]: + def body_fun( + carry: Carry, xs: None + ) -> tuple[Carry, tuple[jtp.PyTree, jtp.PyTree]]: sim, callback_handler = carry # Make sure to pass a mutable version of the simulator to the callbacks sim = sim.mutable(validate=True) # Callback: pre-step - # TODO: should we allow also producing a pre-step output? - sim = pre_step_cb(sim) if pre_step_cb is not None else sim + sim, out_pre_step = ( + pre_step_cb(sim) if pre_step_cb is not None else (sim, None) + ) # Integrate all models step_data = sim.step(clear_inputs=clear_inputs) @@ -460,12 +508,13 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, jtp.PyTree]: # Pack the carry carry = (sim, callback_handler) - return carry, out_post_step + return carry, (out_pre_step, out_post_step) # Integrate over the given horizon - (sim, callback_handler), out_cb_horizon = jax.lax.scan( - f=body_fun, init=carry_init, xs=None, length=horizon_steps - ) + (sim, callback_handler), ( + out_pre_step_horizon, + out_post_step_horizon, + ) = jax.lax.scan(f=body_fun, init=carry_init, xs=None, length=horizon_steps) # Enforce original mutability of the entire simulator sim._set_mutability(original_mutability) @@ -473,5 +522,22 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, jtp.PyTree]: return ( sim if callback_handler is None - else (sim, (callback_handler, out_cb_horizon)) + else ( + sim, + (callback_handler, (out_pre_step_horizon, out_post_step_horizon)), + ) ) + + def vectorize(self: Self, batch_size: int) -> Self: + """ + Inherit docs. + """ + + jaxsim_vec: JaxSim = super().vectorize(batch_size=batch_size) # noqa + + # We need to manually specify the batch size of the handled models + with jaxsim_vec.mutable_context(mutability=Mutability.MUTABLE): + for model in jaxsim_vec.models(): + model.batch_size = batch_size + + return jaxsim_vec diff --git a/src/jaxsim/simulation/simulator_callbacks.py b/src/jaxsim/simulation/simulator_callbacks.py index e288d8f89..00208474d 100644 --- a/src/jaxsim/simulation/simulator_callbacks.py +++ b/src/jaxsim/simulation/simulator_callbacks.py @@ -1,12 +1,14 @@ import abc from typing import Callable, Dict, Tuple -import jaxsim +import jaxsim.high_level.model import jaxsim.typing as jtp from jaxsim.high_level.model import StepData ConfigureCallbackSignature = Callable[["jaxsim.JaxSim"], "jaxsim.JaxSim"] -PreStepCallbackSignature = Callable[["jaxsim.JaxSim"], "jaxsim.JaxSim"] +PreStepCallbackSignature = Callable[ + ["jaxsim.JaxSim"], Tuple["jaxsim.JaxSim", jtp.PyTree] +] PostStepCallbackSignature = Callable[ ["jaxsim.JaxSim", Dict[str, StepData]], Tuple["jaxsim.JaxSim", jtp.PyTree] ] @@ -32,7 +34,7 @@ def pre_step_cb(self) -> PreStepCallbackSignature: return lambda sim: self.pre_step(sim=sim) @abc.abstractmethod - def pre_step(self, sim: "jaxsim.JaxSim") -> "jaxsim.JaxSim": + def pre_step(self, sim: "jaxsim.JaxSim") -> Tuple["jaxsim.JaxSim", jtp.PyTree]: pass diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 82355cf33..94b9508be 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -1,7 +1,6 @@ from typing import Any, Dict, Hashable, List, NamedTuple, Tuple, Union import jax.numpy as jnp -import numpy as np import numpy.typing as npt # JAX types @@ -35,6 +34,6 @@ Tensor = Union[npt.NDArray, ArrayJax] Vector = Array Matrix = Array -Bool = bool +Bool = Union[bool, ArrayJax] Int = Union[int, IntJax] Float = Union[float, FloatJax] diff --git a/src/jaxsim/utils/__init__.py b/src/jaxsim/utils/__init__.py new file mode 100644 index 000000000..b79fd990f --- /dev/null +++ b/src/jaxsim/utils/__init__.py @@ -0,0 +1,8 @@ +from jax_dataclasses._copy_and_mutate import _Mutability as Mutability + +from .jaxsim_dataclass import JaxsimDataclass +from .tracing import not_tracing, tracing +from .vmappable import Vmappable + +# Leave this below the others to prevent circular imports +from .oop import jax_tf # isort: skip diff --git a/src/jaxsim/utils.py b/src/jaxsim/utils/jaxsim_dataclass.py similarity index 54% rename from src/jaxsim/utils.py rename to src/jaxsim/utils/jaxsim_dataclass.py index 07a6bef5e..c719dfb3c 100644 --- a/src/jaxsim/utils.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -1,44 +1,31 @@ import abc import contextlib import copy -from typing import Any, ContextManager, TypeVar +import dataclasses +from typing import ClassVar, Generator -import jax.abstract_arrays import jax.flatten_util -import jax.interpreters.partial_eval import jax_dataclasses -from jax_dataclasses._copy_and_mutate import _Mutability as Mutability import jaxsim.typing as jtp -T = TypeVar("T") +from . import Mutability - -def tracing(var: Any) -> bool: - """Returns True if the variable is being traced by JAX, False otherwise.""" - - return jax.numpy.array( - [ - isinstance(var, t) - for t in ( - jax.abstract_arrays.ShapedArray, - jax.interpreters.partial_eval.DynamicJaxprTracer, - ) - ] - ).any() - - -def not_tracing(var: Any) -> bool: - """Returns True if the variable is not being traced by JAX, False otherwise.""" - - return True if tracing(var) is False else False +try: + from typing import Self +except ImportError: + from typing_extensions import Self +@jax_dataclasses.pytree_dataclass class JaxsimDataclass(abc.ABC): """""" + # This attribute is set by jax_dataclasses + __mutability__: ClassVar[Mutability] = Mutability.FROZEN + @contextlib.contextmanager - def editable(self: T, validate: bool = True) -> ContextManager[T]: + def editable(self: Self, validate: bool = True) -> Generator[Self, None, None]: """""" mutability = ( @@ -48,23 +35,34 @@ def editable(self: T, validate: bool = True) -> ContextManager[T]: with JaxsimDataclass.mutable_context(self.copy(), mutability=mutability) as obj: yield obj - # with jax_dataclasses.copy_and_mutate(self, validate=validate) as self_rw: - # yield self_rw - # - # self_rw._set_mutability(self._mutability()) - @contextlib.contextmanager - def mutable_context(self: T, mutability: Mutability) -> ContextManager[T]: + def mutable_context( + self: Self, mutability: Mutability, restore_after_exception: bool = True + ) -> Generator[Self, None, None]: """""" - original_mutability = self._mutability() + if restore_after_exception: + self_copy = self.copy() - self._set_mutability(mutability) - yield self - - self._set_mutability(original_mutability) + original_mutability = self._mutability() - def is_mutable(self: T, validate: bool = False) -> bool: + def restore_self(): + self._set_mutability(mutability=Mutability.MUTABLE) + for f in dataclasses.fields(self_copy): + setattr(self, f.name, getattr(self_copy, f.name)) + + try: + self._set_mutability(mutability) + yield self + except Exception as e: + if restore_after_exception: + restore_self() + self._set_mutability(original_mutability) + raise e + finally: + self._set_mutability(original_mutability) + + def is_mutable(self, validate: bool = False) -> bool: """""" return ( @@ -91,21 +89,21 @@ def _set_mutability(self, mutability: Mutability) -> None: self, mutable=mutability, visited=set() ) - def mutable(self: T, mutable: bool = True, validate: bool = False) -> T: + def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self: self.set_mutability(mutable=mutable, validate=validate) return self - def copy(self: T) -> T: - obj = copy.deepcopy(self) + def copy(self: Self) -> Self: + obj = jax.tree_util.tree_map(lambda leaf: leaf, self) obj._set_mutability(mutability=self._mutability()) return obj - def replace(self: T, validate: bool = True, **kwargs) -> T: + def replace(self: Self, validate: bool = True, **kwargs) -> Self: with self.editable(validate=validate) as obj: _ = [obj.__setattr__(k, copy.copy(v)) for k, v in kwargs.items()] obj._set_mutability(mutability=self._mutability()) return obj - def flatten(self: T) -> jtp.VectorJax: + def flatten(self) -> jtp.VectorJax: return jax.flatten_util.ravel_pytree(self)[0] diff --git a/src/jaxsim/utils/oop.py b/src/jaxsim/utils/oop.py new file mode 100644 index 000000000..cbca1ae5f --- /dev/null +++ b/src/jaxsim/utils/oop.py @@ -0,0 +1,530 @@ +import contextlib +import dataclasses +import functools +import inspect +import os +from typing import Any, Callable, Generator + +import jax +import jax.flatten_util + +from jaxsim import logging +from jaxsim.utils import tracing + +from . import Mutability, Vmappable + + +class jax_tf: + """ + Class containing decorators applicable to methods of Vmappable objects. + """ + + # Environment variables that can be used to disable the transformations + EnvVarOOP: str = "JAXSIM_OOP_DECORATORS" + EnvVarJitOOP: str = "JAXSIM_OOP_DECORATORS_JIT" + EnvVarVmapOOP: str = "JAXSIM_OOP_DECORATORS_VMAP" + EnvVarCacheOOP: str = "JAXSIM_OOP_DECORATORS_CACHE" + + @staticmethod + def method_ro( + fn: Callable, + jit: bool = True, + static_argnames: tuple[str, ...] | list[str] = (), + vmap: bool | None = None, + vmap_in_axes: tuple[int, ...] | int | None = None, + vmap_out_axes: tuple[int, ...] | int | None = None, + ): + """ + Decorator for r/o methods of classes inheriting from Vmappable. + """ + + return jax_tf.method( + fn=fn, + read_only=True, + validate=True, + jit_enabled=jit, + static_argnames=static_argnames, + vmap_enabled=vmap, + vmap_in_axes=vmap_in_axes, + vmap_out_axes=vmap_out_axes, + ) + + @staticmethod + def method_rw( + fn: Callable, + validate: bool = True, + jit: bool = True, + static_argnames: tuple[str, ...] | list[str] = (), + vmap: bool | None = None, + vmap_in_axes: tuple[int, ...] | int | None = None, + vmap_out_axes: tuple[int, ...] | int | None = None, + ): + """ + Decorator for r/w methods of classes inheriting from Vmappable. + """ + + return jax_tf.method( + fn=fn, + read_only=False, + validate=validate, + jit_enabled=jit, + static_argnames=static_argnames, + vmap_enabled=vmap, + vmap_in_axes=vmap_in_axes, + vmap_out_axes=vmap_out_axes, + ) + + @staticmethod + def method( + fn: Callable, + read_only: bool = True, + validate: bool = True, + jit_enabled: bool = True, + static_argnames: tuple[str, ...] | list[str] = (), + vmap_enabled: bool | None = None, + vmap_in_axes: tuple[int, ...] | int | None = None, + vmap_out_axes: tuple[int, ...] | int | None = None, + ): + """ + Decorator for methods of classes inheriting from Vmappable. + + This decorator enables executing the methods on an object characterized by a + desired mutability, that is selected considering the r/o and validation flags. + It also allows to transform the method with the jit/vmap transformations. + If the Vmappable object is vectorized, the method is automatically vmapped, and + the in_axes are properly post-processed to simplify the combination with jit. + + Args: + fn: The method to decorate. + read_only: Whether the method operates on a read-only object. + validate: Whether r/w methods should preserve the pytree structure. + jit_enabled: Whether to apply the jit transformation. + static_argnames: The names of the arguments that should be static. + vmap_enabled: Whether to apply the vmap transformation. + vmap_in_axes: The in_axes to use for the vmap transformation. + vmap_out_axes: The out_axes to use for the vmap transformation. + + Returns: + The decorated method. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + """The wrapper function that is returned by this decorator.""" + + # Methods of classes inheriting from Vmappable decorated by this wrapper + # automatically support jit/vmap/mutability features when called standalone. + # However, when objects are arguments of plain functions transformed with + # jit/vmap, and decorated methods are called inside those functions, we need + # to disable this decorator to avoid double wrapping and execution errors. + # We do so by iterating over the arguments, and checking whether they are + # being traced by JAX. + for argument in list(args) + list(kwargs.values()): + try: + argument_flat, _ = jax.flatten_util.ravel_pytree(argument) + + if tracing(argument_flat): + return fn(*args, **kwargs) + except: + continue + + # =============================================================== + # Wrap fn so that jit/vmap/mutability transformations are applied + # =============================================================== + + # Initialize the mutability of the instance over which the method is running. + # * In r/o methods, this approach prevents any type of mutation. + # * In r/w methods, this approach allows to catch early JIT recompilations + # caused by unwanted changes in the pytree structure. + if read_only: + mutability = Mutability.FROZEN + else: + mutability = ( + Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION + ) + + # Extract the class instance over which fn is called + instance: Vmappable = args[0] + assert isinstance(instance, Vmappable) + + # Save the original mutability + original_mutability = instance._mutability() + + # Inspect the environment to detect whether to enforce disabling jit/vmap + deco_on = jax_tf.env_var_on(jax_tf.EnvVarOOP) + jit_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarJitOOP) and deco_on + vmap_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarVmapOOP) and deco_on + + # Allow disabling the cache of jit-compiled functions. + # It can be useful for debugging or testing purposes. + wrap_fn = ( + jax_tf.wrap_fn + if jax_tf.env_var_on(jax_tf.EnvVarCacheOOP) and deco_on + else jax_tf.wrap_fn.__wrapped__ + ) + + # Get the transformed function (possibly cached by functools.cache). + # Note that all the arguments of the following methods, when hashed, should + # uniquely identify the returned function so that a new function is built + # when arguments change and either jit or vmap have to be called again. + fn_db = wrap_fn( + fn=fn, # noqa + mutability=mutability, + jit=jit_enabled_env and jit_enabled, + static_argnames=tuple(static_argnames), + vmap=vmap_enabled_env + and ( + vmap_enabled is True + or (vmap_enabled is None and instance.vectorized) + ), + in_axes=vmap_in_axes, + out_axes=vmap_out_axes, + ) + + # Call the transformed (mutable/jit/vmap) method + out, obj = fn_db(*args, **kwargs) + + if read_only: + # Restore the original mutability + instance._set_mutability(mutability=original_mutability) + + return out + + # ================================================================= + # From here we assume that the wrapper is operating on a r/w method + # ================================================================= + + from jax_dataclasses._dataclasses import JDC_STATIC_MARKER + + # Select the right runtime mutability. The only difference here is when a r/w + # method is called on a frozen object. In this case, we enable updating the + # pytree data and preserve its structure only if validation is enabled. + mutability_dict = { + Mutability.MUTABLE_NO_VALIDATION: Mutability.MUTABLE_NO_VALIDATION, + Mutability.MUTABLE: Mutability.MUTABLE, + Mutability.FROZEN: Mutability.MUTABLE + if validate + else Mutability.MUTABLE_NO_VALIDATION, + } + + # We need to replace all the dynamic leafs of the original instance with those + # computed by the functional transformation. + # We do so by iterating over the fields of the jax_dataclasses and ignoring + # all the fields that are marked as static. + # Caveats: https://github.com/ami-iit/jaxsim/pull/48#issuecomment-1746635121. + with instance.mutable_context( + mutability=mutability_dict[instance._mutability()] + ): + for f in dataclasses.fields(instance): # noqa + if ( + hasattr(f, "type") + and hasattr(f.type, "__metadata__") + and JDC_STATIC_MARKER in f.type.__metadata__ + ): + continue + + try: + setattr(instance, f.name, getattr(obj, f.name)) + except AssertionError: + logging.debug(f"Old object:\n{getattr(instance, f.name)}") + logging.debug(f"New object:\n{getattr(obj, f.name)}") + raise RuntimeError(f"Failed to update field '{f.name}'") + + return out + + return wrapper + + @staticmethod + @functools.cache + def wrap_fn( + fn: Callable, + mutability: Mutability, + jit: bool, + static_argnames: tuple[str, ...] | list[str], + vmap: bool, + in_axes: tuple[int, ...] | int | None, + out_axes: tuple[int, ...] | int | None, + ) -> Callable: + """ + Transform a method with jit/vmap and execute it on an object characterized + by the desired mutability. + + Note: + The method should take the object (self) as first argument. + + Note: + This returned transformed method is cached by considering the hash of all + the arguments. It will re-apply jit/vmap transformations only if needed. + + Args: + fn: The method to consider. + mutability: The mutability of the object on which the method is called. + jit: Whether to apply jit transformations. + static_argnames: The names of the arguments that should be considered static. + vmap: Whether to apply vmap transformations. + in_axes: The axes along which to vmap input arguments. + out_axes: The axes along which to vmap output arguments. + + Note: + In order to simplify the application of vmap, we close the method arguments + over all the non-mapped input arguments. Furthermore, for improving the + compatibility with jit, we also close the vmap application over the static + arguments. + + Returns: + The transformed method operating on an object with the desired mutability. + We maintain the same signature of the original method. + """ + + # Extract the signature of the function + sig = inspect.signature(fn) + + # All static arguments must be actual arguments of fn + for name in static_argnames: + if name not in sig.parameters: + raise ValueError(f"Static argument '{name}' not found in {fn}") + + # If in_axes is a tuple, its dimension should match the number of arguments + if isinstance(in_axes, tuple) and len(in_axes) != len(sig.parameters): + msg = "The length of 'in_axes' must match the number of arguments ({})" + raise ValueError(msg.format(len(sig.parameters))) + + # Check that static arguments are not mapped with vmap. + # This case would not work since static arguments are not traces and vmap need + # to trace arguments in order to map them. + if isinstance(in_axes, tuple): + for mapped_axis, arg_name in zip(in_axes, sig.parameters.keys()): + if mapped_axis is not None and arg_name in static_argnames: + raise ValueError( + f"Static argument '{arg_name}' cannot be mapped with vmap" + ) + + def fn_tf_vmap(*args, function_to_vmap: Callable, **kwargs): + """Wrapper applying the vmap transformation""" + + # Canonicalize the arguments so that all of them are kwargs + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + # Build a dictionary mapping all arguments to a mapped axis, even when + # the None is passed (defaults to in_axes=0) or and int is passed (defaults + # to in_axes=). + match in_axes: + case None: + argname_to_mapped_axis = {name: 0 for name in bound.arguments} + case tuple(): + argname_to_mapped_axis = { + name: in_axes[i] for i, name in enumerate(bound.arguments) + } + case int(): + argname_to_mapped_axis = {name: in_axes for name in bound.arguments} + case _: + raise ValueError(in_axes) + + # Build a dictionary (argument_name -> argument) for all mapped arguments. + # Note that a mapped argument is an argument whose axis is not None and + # is not a static jit argument. + vmap_mapped_args = { + arg: value + for arg, value in bound.arguments.items() + if argname_to_mapped_axis[arg] is not None + and arg not in static_argnames + } + + # Build a dictionary (argument_name -> argument) for all unmapped arguments + vmap_unmapped_args = { + arg: value + for arg, value in bound.arguments.items() + if arg not in vmap_mapped_args + } + + # Disable mapping of non-vectorized default arguments + for arg, value in argname_to_mapped_axis.items(): + if arg in vmap_mapped_args and value == sig.parameters[arg].default: + logging.debug(f"Disabling vmapping of default argument '{arg}'") + argname_to_mapped_axis[arg] = None + + # Close the function over the unmapped arguments of vmap + fn_closed = lambda *mapped_args: function_to_vmap( + **vmap_unmapped_args, **dict(zip(vmap_mapped_args.keys(), mapped_args)) + ) + + # Create the in_axes tuple of only the mapped arguments + in_axes_mapped = tuple( + argname_to_mapped_axis[name] for name in vmap_mapped_args + ) + + # If all in_axes are the same, simplify in_axes tuple to be just an integer + if len(set(in_axes_mapped)) == 1: + in_axes_mapped = list(set(in_axes_mapped))[0] + + # If, instead, in_axes has different elements, we need to replace the mapped + # axis of "self" with a pytree having as leafs the mapped axis. + # This is because the vmap in_axes specification must be a tree prefix of + # the corresponding value. + if isinstance(in_axes_mapped, tuple) and "self" in vmap_mapped_args: + argname_to_mapped_axis["self"] = jax.tree_util.tree_map( + lambda _: argname_to_mapped_axis["self"], vmap_mapped_args["self"] + ) + in_axes_mapped = tuple( + argname_to_mapped_axis[name] for name in vmap_mapped_args + ) + + # Apply the vmap transformation and call the function passing only the + # mapped arguments. The unmapped arguments have been closed over. + # Note: we altered the "in_axes" tuple so that it does not have any + # None elements. + # Note: if "in_axes_mapped" is a tuple, the following fails if we pass kwargs, + # we need to pass the unpacked args tuple instead. + return jax.vmap( + fn_closed, + in_axes=in_axes_mapped, + **dict(out_axes=out_axes) if out_axes is not None else {}, + )(*list(vmap_mapped_args.values())) + + def fn_tf_jit(*args, function_to_jit: Callable, **kwargs): + """Wrapper applying the jit transformation""" + + # Canonicalize the arguments so that all of them are kwargs + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + # Apply the jit transformation and call the function passing all arguments + # as keyword arguments + return jax.jit(function_to_jit, static_argnames=static_argnames)( + **bound.arguments + ) + + # First applied wrapper that executes fn in a mutable context + fn_mutable = functools.partial( + jax_tf.call_class_method_in_mutable_context, + fn=fn, + jit=jit, + mutability=mutability, + ) + + # Second applied wrapper that transforms fn with vmap + fn_vmap = ( + fn_mutable + if not vmap + else functools.partial(fn_tf_vmap, function_to_vmap=fn_mutable) + ) + + # Third applied wrapper that transforms fn with jit + fn_jit_vmap = ( + fn_vmap + if not jit + else functools.partial(fn_tf_jit, function_to_jit=fn_vmap) + ) + + return fn_jit_vmap + + @staticmethod + def call_class_method_in_mutable_context( + *args, fn: Callable, jit: bool, mutability: Mutability, **kwargs + ) -> tuple[Any, Vmappable]: + """ + Wrapper to call a method on an object with the desired mutable context. + + Args: + fn: The method to call. + jit: Whether the method is being jit compiled or not. + mutability: The desired mutability context. + *args: The positional arguments to pass to the method (including self). + **kwargs: The keyword arguments to pass to the method. + + Returns: + A tuple containing the return value of the method and the object + possibly updated by the method if it is in read-write. + + Note: + This approach enables to jit-compile methods of a stateful object without + leaking traces, therefore obtaining a jax-compatible OOP pattern. + """ + + # Log here whether the method is being jit compiled or not. + # This log message does not get printed from compiled code, so here is the + # most appropriate place to be sure that we log it correctly. + if jit: + logging.debug(msg=f"JIT compiling {fn}") + + # Canonicalize the arguments so that all of them are kwargs + sig = inspect.signature(fn) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + # Extract the class instance over which fn is called + instance: Vmappable = bound.arguments["self"] + + # Select the right mutability. If the instance is mutable with validation + # disabled, we override the input mutability so that we do not fail in case + # of mismatched tree structure. + mut = ( + Mutability.MUTABLE_NO_VALIDATION + if instance._mutability() is Mutability.MUTABLE_NO_VALIDATION + else mutability + ) + + # Call fn in a mutable context + with instance.mutable_context(mutability=mut): + # Methods could call other decorated methods. When it happens, the decorator + # of the called method is invoked, that applies jit and vmap transformations. + # This is not desired as it calls vmap inside an already vmapped method. + # We work around this occurrence by disabling the jit/vmap decorators of all + # methods called inside fn through a context manager. + # Note that we already work around this in the beginning of the wrapper + # function by detecting traced arguments, but the decorator works also + # when jit=False and vmap=False, therefore only enforcing the mutability. + with jax_tf.disabled_oop_decorators(): + out = fn(**bound.arguments) + + return out, instance + + @staticmethod + def env_var_on(var_name: str, default_value: str = "1") -> bool: + """ + Check whether an environment variable is set to a value that is considered on. + + Args: + var_name: The name of the environment variable. + default_value: The default variable value to consider if the variable has not + been exported. + + Returns: + True if the environment variable contains an on value, False otherwise. + """ + + on_values = {"1", "true", "on", "yes"} + return os.environ.get(var_name, default_value).lower() in on_values + + @staticmethod + @contextlib.contextmanager + def disabled_oop_decorators() -> Generator[None, None, None]: + """ + Context manager to disable the application of jax transformations performed by + the decorators of this class. + + Note: when the transformations are disabled, the only logic still applied is + the selection of the object mutability over which the method is running. + """ + + # Check whether the environment variable is part of the environment and + # save its value. We restore the original value before exiting the context. + env_cache = ( + None if jax_tf.EnvVarOOP not in os.environ else os.environ[jax_tf.EnvVarOOP] + ) + + # Disable both jit and vmap transformations + os.environ[jax_tf.EnvVarOOP] = "0" + + try: + # Execute the code in the context with disabled transformations + yield + + finally: + # Restore the original value of the environment variable or remove it if + # it was not present before entering the context + if env_cache is not None: + os.environ[jax_tf.EnvVarOOP] = env_cache + else: + _ = os.environ.pop(jax_tf.EnvVarOOP) diff --git a/src/jaxsim/utils/tracing.py b/src/jaxsim/utils/tracing.py new file mode 100644 index 000000000..9d40fc0fd --- /dev/null +++ b/src/jaxsim/utils/tracing.py @@ -0,0 +1,25 @@ +from typing import Any + +import jax._src.core +import jax.flatten_util +import jax.interpreters.partial_eval + + +def tracing(var: Any) -> bool | jax.Array: + """Returns True if the variable is being traced by JAX, False otherwise.""" + + return jax.numpy.array( + [ + isinstance(var, t) + for t in ( + jax._src.core.Tracer, + jax.interpreters.partial_eval.DynamicJaxprTracer, + ) + ] + ).any() + + +def not_tracing(var: Any) -> bool: + """Returns True if the variable is not being traced by JAX, False otherwise.""" + + return True if tracing(var) is False else False diff --git a/src/jaxsim/utils/vmappable.py b/src/jaxsim/utils/vmappable.py new file mode 100644 index 000000000..0e449f4b8 --- /dev/null +++ b/src/jaxsim/utils/vmappable.py @@ -0,0 +1,117 @@ +import dataclasses +from typing import Type + +import jax +import jax.numpy as jnp +import jax_dataclasses + +from . import JaxsimDataclass, Mutability + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + + +@jax_dataclasses.pytree_dataclass +class Vmappable(JaxsimDataclass): + """Abstract class with utilities for vmappable pytrees.""" + + batch_size: jax_dataclasses.Static[int] = dataclasses.field( + default=int(0), repr=False, compare=False, hash=False, kw_only=True + ) + + @property + def vectorized(self) -> bool: + """Marks this pytree as vectorized.""" + + return self.batch_size > 0 + + @classmethod + def build_from_list(cls: Type[Self], list_of_obj: list[Self]) -> Self: + """ + Build a vectorized pytree from a list of pytree of the same type. + + Args: + list_of_obj: The list of pytrees to vectorize. + + Returns: + The vectorized pytree having as leaves the stacked leaves of the input list. + """ + + if set(type(el) for el in list_of_obj) != {cls}: + msg = "The input list must contain only objects of type '{}'" + raise ValueError(msg.format(cls.__name__)) + + # Create a pytree by stacking all the leafs of the input list + data_vec: Vmappable = jax.tree_map( + lambda *leafs: jnp.array(leafs), *list_of_obj + ) + + # Store the batch dimension + with data_vec.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + data_vec.batch_size = len(list_of_obj) + + # Detect the most common mutability in the input list + mutabilities = [e._mutability() for e in list_of_obj] + mutability = max(set(mutabilities), key=mutabilities.count) + + # Update the mutability of the vectorized pytree + data_vec._set_mutability(mutability) + + return data_vec + + def vectorize(self: Self, batch_size: int) -> Self: + """ + Return a vectorized version of this pytree. + + Args: + batch_size: The batch size. + + Returns: + A vectorized version of this pytree obtained by stacking the leaves of the + original pytree along a new batch dimension (the first one). + """ + + if self.vectorized: + raise RuntimeError("Cannot vectorize an already vectorized object") + + if batch_size == 0: + return self.copy() + + # TODO validate if mutability is maintained + + return self.__class__.build_from_list(list_of_obj=[self] * batch_size) + + def extract_element(self: Self, index: int) -> Self: + """ + Extract the i-th element from a vectorized pytree. + + Args: + index: The index of the element to extract. + + Returns: + A non vectorized pytree obtained by extracting the i-th element from the + vectorized pytree. + """ + + if index < 0: + raise ValueError("The index of the desired element cannot be negative") + + if index == 0 and self.batch_size == 0: + return self.copy() + + if not self.vectorized: + raise RuntimeError("Cannot extract elements from a non-vectorized object") + + if index >= self.batch_size: + raise ValueError("The index must be smaller than the batch size") + + # Get the i-th pytree by extracting the i-th element from the vectorized pytree + data = jax.tree_map(lambda leaf: leaf[index], self) + + # Update the batch size of the extracted scalar pytree + with data.mutable_context(mutability=Mutability.MUTABLE): + data.batch_size = 0 + + return data diff --git a/tests/test_ad_physics.py b/tests/test_ad_physics.py new file mode 100644 index 000000000..9c7db2c0e --- /dev/null +++ b/tests/test_ad_physics.py @@ -0,0 +1,190 @@ +import jax.numpy as jnp +import numpy as np +import pytest +from jax.test_util import check_grads +from pytest import param as p + +from jaxsim.high_level.common import VelRepr +from jaxsim.high_level.model import Model + +from . import utils_models, utils_rng +from .utils_models import Robot + + +@pytest.mark.parametrize( + "robot, vel_repr", + [ + p(*[Robot.Ur10, VelRepr.Inertial], id="Ur10-Inertial"), + p(*[Robot.AnymalC, VelRepr.Inertial], id="AnymalC-Inertial"), + p(*[Robot.Cassie, VelRepr.Inertial], id="Cassie-Inertial"), + ], +) +def test_ad_physics(robot: utils_models.Robot, vel_repr: VelRepr) -> None: + """Unit test of the application of Automatic Differentiation on RBD algorithms.""" + + robot = Robot.Ur10 + vel_repr = VelRepr.Inertial + + # Initialize the gravity + gravity = np.array([0, 0, -10.0]) + + # Get the URDF of the robot + urdf_file_path = utils_models.ModelFactory.get_model_description(robot=robot) + + # Build the high-level model + model = Model.build_from_model_description( + model_description=urdf_file_path, + vel_repr=vel_repr, + gravity=gravity, + is_urdf=True, + ).mutable(mutable=True, validate=True) + + # Initialize the model with a random state + model.data.model_state = utils_rng.random_physics_model_state( + physics_model=model.physics_model + ) + + # Initialize the model with a random input + model.data.model_input = utils_rng.random_physics_model_input( + physics_model=model.physics_model + ) + + # ======================== + # Extract state and inputs + # ======================== + + # Extract the physics model used in the low-level physics algorithms + physics_model = model.physics_model + + # State + s = model.joint_positions() + ṡ = model.joint_velocities() + xfb = model.data.model_state.xfb() + + # Inputs + f_ext = model.external_forces() + tau = model.joint_generalized_forces_targets() + + # Perturbation used for computing finite differences + ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3) + + # ===================================================== + # Check first-order and second-order derivatives of ABA + # ===================================================== + + import jaxsim.physics.algos.aba + + aba = lambda xfb, s, ṡ, tau, f_ext: jaxsim.physics.algos.aba.aba( + model=physics_model, xfb=xfb, q=s, qd=ṡ, tau=tau, f_ext=f_ext + ) + + check_grads( + f=aba, + args=(xfb, s, ṡ, tau, f_ext), + order=2, + modes=["rev", "fwd"], + eps=ε, + ) + + # ====================================================== + # Check first-order and second-order derivatives of RNEA + # ====================================================== + + import jaxsim.physics.algos.rnea + + W_v̇_WB = utils_rng.get_rng().uniform(size=6, low=-1) + s̈ = utils_rng.get_rng().uniform(size=physics_model.dofs(), low=-1) + + rnea = lambda xfb, s, ṡ, s̈, W_v̇_WB, f_ext: jaxsim.physics.algos.rnea.rnea( + model=physics_model, xfb=xfb, q=s, qd=ṡ, qdd=s̈, a0fb=W_v̇_WB, f_ext=f_ext + ) + + check_grads( + f=rnea, + args=(xfb, s, ṡ, s̈, W_v̇_WB, f_ext), + order=2, + modes=["rev", "fwd"], + eps=ε, + ) + + # ====================================================== + # Check first-order and second-order derivatives of CRBA + # ====================================================== + + import jaxsim.physics.algos.crba + + crba = lambda s: jaxsim.physics.algos.crba.crba(model=physics_model, q=s) + + check_grads( + f=crba, + args=(s,), + order=2, + modes=["rev", "fwd"], + eps=ε, + ) + + # ==================================================== + # Check first-order and second-order derivatives of FK + # ==================================================== + + import jaxsim.physics.algos.forward_kinematics + + fk = ( + lambda xfb, s: jaxsim.physics.algos.forward_kinematics.forward_kinematics_model( + model=physics_model, xfb=xfb, q=s + ) + ) + + check_grads( + f=fk, + args=(xfb, s), + order=2, + modes=["rev", "fwd"], + eps=ε, + ) + + # ========================================================== + # Check first-order and second-order derivatives of Jacobian + # ========================================================== + + import jaxsim.physics.algos.jacobian + + link_indices = [l.index() for l in model.links()] + + jacobian = lambda s: jaxsim.physics.algos.jacobian.jacobian( + model=physics_model, q=s, body_index=link_indices[-1] + ) + + check_grads( + f=jacobian, + args=(s,), + order=2, + modes=["rev", "fwd"], + eps=ε, + ) + + # ===================================================================== + # Check first-order and second-order derivatives of soft contacts model + # ===================================================================== + + import jaxsim.physics.algos.soft_contacts + + p = utils_rng.get_rng().uniform(size=3, low=-1) + v = utils_rng.get_rng().uniform(size=3, low=-1) + m = utils_rng.get_rng().uniform(size=3, low=-1) + + parameters = jaxsim.physics.algos.soft_contacts.SoftContactsParams.build( + K=10_000, D=20.0, mu=0.5 + ) + + soft_contacts = lambda p, v, m: jaxsim.physics.algos.soft_contacts.SoftContacts( + parameters=parameters + ).contact_model(position=p, velocity=v, tangential_deformation=m) + + check_grads( + f=soft_contacts, + args=(p, v, m), + order=2, + modes=["rev", "fwd"], + eps=ε, + ) diff --git a/tests/test_eom.py b/tests/test_eom.py index f1b528284..4c0926382 100644 --- a/tests/test_eom.py +++ b/tests/test_eom.py @@ -1,33 +1,32 @@ import pathlib -import jax +import jax.numpy as jnp import numpy as np import pytest +from pytest import param as p from jaxsim.high_level.common import VelRepr from jaxsim.high_level.model import Model from . import utils_idyntree, utils_models, utils_rng +from .utils_models import Robot @pytest.mark.parametrize( "robot, vel_repr", [ - (utils_models.Robot.DoublePendulum, VelRepr.Inertial), - (utils_models.Robot.DoublePendulum, VelRepr.Body), - (utils_models.Robot.DoublePendulum, VelRepr.Mixed), - (utils_models.Robot.Ur10, VelRepr.Inertial), - (utils_models.Robot.Ur10, VelRepr.Body), - (utils_models.Robot.Ur10, VelRepr.Mixed), - (utils_models.Robot.AnymalC, VelRepr.Inertial), - (utils_models.Robot.AnymalC, VelRepr.Body), - (utils_models.Robot.AnymalC, VelRepr.Mixed), - (utils_models.Robot.Cassie, VelRepr.Inertial), - (utils_models.Robot.Cassie, VelRepr.Body), - (utils_models.Robot.Cassie, VelRepr.Mixed), - # (utils_models.Robot.iCub, VelRepr.Inertial), - # (utils_models.Robot.iCub, VelRepr.Body), - # (utils_models.Robot.iCub, VelRepr.Mixed), + p(*[Robot.DoublePendulum, VelRepr.Inertial], id="DoublePendulum-Inertial"), + p(*[Robot.DoublePendulum, VelRepr.Body], id="DoublePendulum-Body"), + p(*[Robot.DoublePendulum, VelRepr.Mixed], id="DoublePendulum-Mixed"), + p(*[Robot.Ur10, VelRepr.Inertial], id="Ur10-Inertial"), + p(*[Robot.Ur10, VelRepr.Body], id="Ur10-Body"), + p(*[Robot.Ur10, VelRepr.Mixed], id="Ur10-Mixed"), + p(*[Robot.AnymalC, VelRepr.Inertial], id="AnymalC-Inertial"), + p(*[Robot.AnymalC, VelRepr.Body], id="AnymalC-Body"), + p(*[Robot.AnymalC, VelRepr.Mixed], id="AnymalC-Mixed"), + p(*[Robot.Cassie, VelRepr.Inertial], id="Cassie-Inertial"), + p(*[Robot.Cassie, VelRepr.Body], id="Cassie-Body"), + p(*[Robot.Cassie, VelRepr.Mixed], id="Cassie-Mixed"), ], ) def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None: @@ -66,7 +65,7 @@ def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None: kin_dyn = utils_idyntree.KinDynComputations.build( urdf=pathlib.Path(urdf_file_path), - considered_joints=model_jaxsim.joint_names(), + considered_joints=list(model_jaxsim.joint_names()), vel_repr=vel_repr, gravity=gravity, ) @@ -78,7 +77,7 @@ def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None: base_velocity=np.array(model_jaxsim.base_velocity()), ) - assert kin_dyn.joint_names() == model_jaxsim.joint_names() + assert kin_dyn.joint_names() == list(model_jaxsim.joint_names()) assert kin_dyn.gravity == pytest.approx(model_jaxsim.physics_model.gravity[0:3]) assert kin_dyn.joint_positions() == pytest.approx(model_jaxsim.joint_positions()) assert kin_dyn.joint_velocities() == pytest.approx(model_jaxsim.joint_velocities()) @@ -102,15 +101,14 @@ def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None: # Test individual terms of the EoM # ================================ - jit_enabled = True - fn = jax.jit if jit_enabled else lambda x: x - - M_jaxsim = fn(model_jaxsim.free_floating_mass_matrix)() - g_jaxsim = fn(model_jaxsim.free_floating_gravity_forces)() - h_jaxsim = fn(model_jaxsim.free_floating_bias_forces)() - J_jaxsim = np.vstack([link.jacobian() for link in model_jaxsim.links()]) + M_jaxsim = model_jaxsim.free_floating_mass_matrix() + g_jaxsim = model_jaxsim.free_floating_gravity_forces() + J_jaxsim = jnp.vstack([link.jacobian() for link in model_jaxsim.links()]) + h_jaxsim = model_jaxsim.free_floating_bias_forces() + # Support both fixed-base and floating-base models by slicing the first six rows sl = np.s_[0:] if model_jaxsim.floating_base() else np.s_[6:] + assert M_jaxsim[sl, sl] == pytest.approx(M_idt[sl, sl], abs=1e-3) assert g_jaxsim[sl] == pytest.approx(g_idt[sl], abs=1e-3) assert h_jaxsim[sl] == pytest.approx(h_idt[sl], abs=1e-3) @@ -120,13 +118,13 @@ def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None: # Test the forward dynamics computed with CRB # =========================================== - J_ff = fn(model_jaxsim.generalized_free_floating_jacobian)() - f_ext = fn(model_jaxsim.external_forces)().flatten() - nud = np.hstack(fn(model_jaxsim.forward_dynamics_crb)(tau=tau)) + J_ff = model_jaxsim.generalized_free_floating_jacobian() + f_ext = model_jaxsim.external_forces().flatten() + ν̇ = np.hstack(model_jaxsim.forward_dynamics_crb(tau=tau)) S = np.block( [np.zeros(shape=(model_jaxsim.dofs(), 6)), np.eye(model_jaxsim.dofs())] ).T assert h_jaxsim[sl] == pytest.approx( - (S @ tau + J_ff.T @ f_ext - M_jaxsim @ nud)[sl], abs=1e-3 + (S @ tau + J_ff.T @ f_ext - M_jaxsim @ ν̇)[sl], abs=1e-3 ) diff --git a/tests/test_forward_dynamics.py b/tests/test_forward_dynamics.py index ec115a4c5..b4bb72a37 100644 --- a/tests/test_forward_dynamics.py +++ b/tests/test_forward_dynamics.py @@ -1,31 +1,29 @@ -import jax import numpy as np import pytest +from pytest import param as p from jaxsim.high_level.common import VelRepr from jaxsim.high_level.model import Model from . import utils_models, utils_rng +from .utils_models import Robot @pytest.mark.parametrize( "robot, vel_repr", [ - (utils_models.Robot.DoublePendulum, VelRepr.Inertial), - (utils_models.Robot.DoublePendulum, VelRepr.Body), - (utils_models.Robot.DoublePendulum, VelRepr.Mixed), - (utils_models.Robot.Ur10, VelRepr.Inertial), - (utils_models.Robot.Ur10, VelRepr.Body), - (utils_models.Robot.Ur10, VelRepr.Mixed), - (utils_models.Robot.AnymalC, VelRepr.Inertial), - (utils_models.Robot.AnymalC, VelRepr.Body), - (utils_models.Robot.AnymalC, VelRepr.Mixed), - (utils_models.Robot.Cassie, VelRepr.Inertial), - (utils_models.Robot.Cassie, VelRepr.Body), - (utils_models.Robot.Cassie, VelRepr.Mixed), - # (utils_models.Robot.iCub, VelRepr.Inertial), - # (utils_models.Robot.iCub, VelRepr.Body), - # (utils_models.Robot.iCub, VelRepr.Mixed), + p(*[Robot.DoublePendulum, VelRepr.Inertial], id="DoublePendulum-Inertial"), + p(*[Robot.DoublePendulum, VelRepr.Body], id="DoublePendulum-Body"), + p(*[Robot.DoublePendulum, VelRepr.Mixed], id="DoublePendulum-Mixed"), + p(*[Robot.Ur10, VelRepr.Inertial], id="Ur10-Inertial"), + p(*[Robot.Ur10, VelRepr.Body], id="Ur10-Body"), + p(*[Robot.Ur10, VelRepr.Mixed], id="Ur10-Mixed"), + p(*[Robot.AnymalC, VelRepr.Inertial], id="AnymalC-Inertial"), + p(*[Robot.AnymalC, VelRepr.Body], id="AnymalC-Body"), + p(*[Robot.AnymalC, VelRepr.Mixed], id="AnymalC-Mixed"), + p(*[Robot.Cassie, VelRepr.Inertial], id="Cassie-Inertial"), + p(*[Robot.Cassie, VelRepr.Body], id="Cassie-Body"), + p(*[Robot.Cassie, VelRepr.Mixed], id="Cassie-Mixed"), ], ) def test_aba(robot: utils_models.Robot, vel_repr: VelRepr) -> None: @@ -61,15 +59,13 @@ def test_aba(robot: utils_models.Robot, vel_repr: VelRepr) -> None: tau = model.joint_generalized_forces_targets() # Compute model acceleration with ABA - jit_enabled = True - fn = jax.jit if jit_enabled else lambda x: x - a_WB_aba, sdd_aba = fn(model.forward_dynamics_aba)(tau=tau) + v̇_WB_aba, s̈_aba = model.forward_dynamics_aba(tau=tau) # ============================================== # Compute forward dynamics with dedicated method # ============================================== - a_WB, sdd = model.forward_dynamics_crb(tau=tau) + v̇_WB, s̈ = model.forward_dynamics_crb(tau=tau) - assert sdd.squeeze() == pytest.approx(sdd_aba.squeeze(), abs=0.5) - assert a_WB.squeeze() == pytest.approx(a_WB_aba.squeeze(), abs=0.2) + assert s̈.squeeze() == pytest.approx(s̈_aba.squeeze(), abs=0.5) + assert v̇_WB.squeeze() == pytest.approx(v̇_WB_aba.squeeze(), abs=0.2) diff --git a/tests/test_jax_oop.py b/tests/test_jax_oop.py new file mode 100644 index 000000000..f887ecfae --- /dev/null +++ b/tests/test_jax_oop.py @@ -0,0 +1,422 @@ +import dataclasses +import io +from contextlib import redirect_stdout +from typing import Any, Type + +import jax +import jax.numpy as jnp +import jax_dataclasses +import numpy as np +import pytest + +from jaxsim.utils import Mutability, Vmappable, oop + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + + +@jax_dataclasses.pytree_dataclass +class AlgoData(Vmappable): + """Class storing vmappable data of a given algorithm.""" + + counter: jax.Array = dataclasses.field( + default_factory=lambda: jnp.array(0, dtype=jnp.uint64) + ) + + @classmethod + def build(cls: Type[Self], counter: jax.typing.ArrayLike) -> Self: + """Builder method. Helpful for enforcing type and shape of fields.""" + + # Counter can be int / scalar numpy array / scalar jax array / etc. + if jnp.array(counter).squeeze().size != 1: + raise ValueError("The counter must be a scalar") + + # Create the object enforcing `counter` to be a scalar jax array + data = AlgoData( + counter=jnp.array(counter, dtype=jnp.uint64).squeeze(), + ) + + return data + + +def test_data(): + """Test AlgoData class.""" + + data1 = AlgoData.build(counter=0) + data2 = AlgoData.build(counter=np.array(10)) + data3 = AlgoData.build(counter=jnp.array(50)) + + assert isinstance(data1.counter, jax.Array) and data1.counter.dtype == jnp.uint64 + assert isinstance(data2.counter, jax.Array) and data2.counter.dtype == jnp.uint64 + assert isinstance(data3.counter, jax.Array) and data3.counter.dtype == jnp.uint64 + + assert data1.batch_size == 0 + assert data2.batch_size == 0 + assert data3.batch_size == 0 + + # ================== + # Vectorizing PyTree + # ================== + + for batch_size in (0, 10, 100): + data_vec = data1.vectorize(batch_size=batch_size) + + assert data_vec.batch_size == batch_size + + if batch_size > 0: + assert data_vec.counter.shape[0] == batch_size + + # ========================================= + # Extracting element from vectorized PyTree + # ========================================= + + data_vec = AlgoData.build_from_list(list_of_obj=[data1, data2, data3]) + assert data_vec.batch_size == 3 + assert data_vec.extract_element(index=0) == data1 + assert data_vec.extract_element(index=1) == data2 + assert data_vec.extract_element(index=2) == data3 + + with pytest.raises(ValueError): + _ = data_vec.extract_element(index=3) + + out = data1.extract_element(index=0) + assert out == data1 + assert id(out) != id(data1) + + with pytest.raises(RuntimeError): + _ = data1.extract_element(index=1) + + with pytest.raises(ValueError): + _ = AlgoData.build_from_list(list_of_obj=[data1, data2, data3, 42]) + + +@jax_dataclasses.pytree_dataclass +class MyClassWithAlgorithms(Vmappable): + """ + Class to demonstrate how to use `Vmappable`. + """ + + # Dynamic data of the algorithm + data: AlgoData = dataclasses.field(default=None) + + # Static attribute of the pytree (triggers recompilation if changed) + double_input: jax_dataclasses.Static[bool] = dataclasses.field(default=None) + + # Non-static attribute of the pytree that is not transparently vmap-able. + done: jax.typing.ArrayLike = dataclasses.field( + default_factory=lambda: jnp.array(False, dtype=bool) + ) + + # Additional leaves to test the behaviour of mutable and immutable python objects + my_tuple: tuple[int] = dataclasses.field(default=tuple(jnp.array([1, 2, 3]))) + my_list: list[int] = dataclasses.field( + default_factory=lambda: [4, 5, 6], init=False + ) + my_array: jax.Array = dataclasses.field( + default_factory=lambda: jnp.array([10, 20, 30]) + ) + + @classmethod + def build(cls: Type[Self], double_input: bool = False) -> Self: + """""" + + obj = MyClassWithAlgorithms() + + with obj.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + obj.data = AlgoData.build(counter=0) + obj.double_input = jnp.array(double_input) + + return obj + + @oop.jax_tf.method_ro + def algo_ro(self, advance: int | jax.typing.ArrayLike) -> Any: + """This is a read-only algorithm. It does not alter any pytree leaf.""" + + # This should be printed only the first execution since it is disabled + # in the execution of the JIT-compiled function. + print("__algo_ro__") + + # Use the dynamic condition that doubles the input value + mul = jax.lax.select(self.double_input, 2, 1) + + # Increase the counter + counter_old = jnp.atleast_1d(self.data.counter)[0] + counter_new = counter_old + mul * advance + + # Return the updated counter + return counter_new + + @oop.jax_tf.method_rw + def algo_rw(self, advance: int | jax.typing.ArrayLike) -> Any: + """ + This is a read-write algorithm. It may alter pytree leaves either belonging + to the vmappable data or generic non-static dataclass attributes. + """ + + print(self) + + # This should be printed only the first execution since it is disabled + # in the execution of the JIT-compiled function. + print("__algo_rw__") + + # Use the dynamic condition that doubles the input value + mul = jax.lax.select(self.double_input, 2, 1) + + # Increase the internal counter + counter_old = jnp.atleast_1d(self.data.counter)[0] + self.data.counter = jnp.array(counter_old + mul * advance, dtype=jnp.uint64) + + # Update the non-static and non-vmap-able attribute + self.done = jax.lax.cond( + pred=self.data.counter > 100, + true_fun=lambda _: jnp.array(True), + false_fun=lambda _: jnp.array(False), + operand=None, + ) + + print(self) + + # Return the updated counter + return self.data.counter + + +def test_mutability(): + """Test MyClassWithAlgorithms class.""" + + # Build the object + obj_ro = MyClassWithAlgorithms.build(double_input=True) + + # By default, pytrees built with jax_dataclasses are frozen (read-only) + assert obj_ro._mutability() == Mutability.FROZEN + with pytest.raises(dataclasses.FrozenInstanceError): + obj_ro.data.counter = 42 + + # Data can be changed through a context manager, in this case operating on a copy... + with obj_ro.editable(validate=True) as obj_ro_copy: + obj_ro_copy.data.counter = jnp.array(42, dtype=obj_ro.data.counter.dtype) + assert obj_ro_copy.data.counter == pytest.approx(42) + assert obj_ro.data.counter != pytest.approx(42) + + # ... or a context manager that does not copy the pytree... + with obj_ro.mutable_context(mutability=Mutability.MUTABLE): + obj_ro.data.counter = jnp.array(42, dtype=obj_ro.data.counter.dtype) + assert obj_ro.data.counter == pytest.approx(42) + + # ... that raises if the leafs change type + with pytest.raises(AssertionError): + with obj_ro.mutable_context(mutability=Mutability.MUTABLE): + obj_ro.data.counter = 42 + + # Pytrees can be copied... + obj_ro_copy = obj_ro.copy() + assert id(obj_ro) != id(obj_ro_copy) + # ... operation that does not copy the leaves + # TODO describe + assert id(obj_ro.done) == id(obj_ro_copy.done) + assert id(obj_ro.data.counter) == id(obj_ro_copy.data.counter) + assert id(obj_ro.my_array) == id(obj_ro_copy.my_array) + assert id(obj_ro.my_tuple) != id(obj_ro_copy.my_tuple) + assert id(obj_ro.my_list) != id(obj_ro_copy.my_list) + + # They can be converted as mutable pytrees to update their values without + # using context managers (maybe useful for debugging or quick prototyping) + obj_rw = obj_ro.copy().mutable(validate=True) + assert obj_rw._mutability() == Mutability.MUTABLE + obj_rw.data.counter = jnp.array(42, dtype=obj_rw.data.counter.dtype) + + # However, with validation enabled, this works only if the leaf does not + # change its type (shape, dtype, weakness, ...) + with pytest.raises(AssertionError): + obj_rw.data.counter = 100 + with pytest.raises(AssertionError): + obj_rw.data.counter = jnp.array(100, dtype=float) + with pytest.raises(AssertionError): + obj_rw.data.counter = jnp.array([100, 200], dtype=obj_rw.data.counter.dtype) + + # Instead, with validation disabled, the pytree structure can be altered + # (and this might cause JIT recompilations, so use it at your own risk) + obj_rw_noval = obj_ro.copy().mutable(validate=False) + assert obj_rw_noval._mutability() == Mutability.MUTABLE_NO_VALIDATION + obj_rw_noval.data.counter = jnp.array(42, dtype=obj_rw.data.counter.dtype) + + # Now this should work without exceptions + obj_rw_noval.data.counter = 100 + obj_rw_noval.data.counter = jnp.array(100, dtype=float) + obj_rw_noval.data.counter = jnp.array([100, 200], dtype=obj_rw.data.counter.dtype) + + # Build another object and check mutability changes + obj_ro = MyClassWithAlgorithms.build(double_input=True) + assert obj_ro.is_mutable(validate=True) is False + assert obj_ro.is_mutable(validate=False) is False + + obj_rw_val = obj_ro.mutable(validate=True) + assert id(obj_ro) == id(obj_rw_val) + assert obj_rw_val.is_mutable(validate=True) is True + assert obj_rw_val.is_mutable(validate=False) is False + + obj_rw_noval = obj_rw_val.mutable(validate=False) + assert id(obj_rw_noval) == id(obj_rw_val) + assert obj_rw_noval.is_mutable(validate=True) is False + assert obj_rw_noval.is_mutable(validate=False) is True + + # Checking mutable leaves behavior + obj_rw = MyClassWithAlgorithms.build(double_input=True).mutable(validate=True) + obj_rw_copy = obj_rw.copy() + + # Memory of JAX arrays cannot be altered in place so this is safe + obj_rw.my_array = obj_rw.my_array.at[1].set(-20) + assert obj_rw_copy.my_array[1] != -20 + + # Tuples are immutable so this should be safe too + obj_rw.my_tuple = tuple(jnp.array([1, -2, 3])) + assert obj_rw_copy.my_array[1] != -2 + + # Lists are treated as tuples (they are not leaves) but since they are mutable, + # their id changes + obj_rw.my_list[1] = -5 + assert obj_rw_copy.my_list[1] != -5 + + # Check that exceptions in mutable context do not alter the object + obj_ro = MyClassWithAlgorithms.build(double_input=True) + assert obj_ro.data.counter == 0 + assert obj_ro.double_input == jnp.array(True) + + with pytest.raises(RuntimeError): + with obj_ro.mutable_context(mutability=Mutability.MUTABLE): + obj_ro.double_input = jnp.array(False, dtype=obj_ro.double_input.dtype) + obj_ro.data.counter = jnp.array(33, dtype=obj_ro.data.counter.dtype) + raise RuntimeError + assert obj_ro.data.counter == 0 + assert obj_ro.double_input == jnp.array(True) + + +def test_decorators_jit_compilation(): + """Test JIT features of MyClassWithAlgorithms class.""" + + obj = MyClassWithAlgorithms.build(double_input=False) + assert obj.data.counter == 0 + assert obj.is_mutable(validate=True) is False + assert obj.is_mutable(validate=False) is False + + # JIT compilation should happen only the first function call. + # We test this by checking that the first execution prints some output. + with io.StringIO() as buf, redirect_stdout(buf): + _ = obj.algo_ro(advance=1) + printed = buf.getvalue() + assert "__algo_ro__" in printed + with io.StringIO() as buf, redirect_stdout(buf): + _ = obj.algo_ro(advance=1) + printed = buf.getvalue() + assert "__algo_ro__" not in printed + + # JIT compilation should happen only the first function call. + # We test this by checking that the first execution prints some output. + with io.StringIO() as buf, redirect_stdout(buf): + _ = obj.algo_rw(advance=1) + printed = buf.getvalue() + assert "__algo_rw__" in printed + with io.StringIO() as buf, redirect_stdout(buf): + _ = obj.algo_rw(advance=1) + printed = buf.getvalue() + assert "__algo_rw__" not in printed + + # Create a new object + obj = MyClassWithAlgorithms.build(double_input=False) + + # New objects should be able to re-use the JIT-compiled functions from other objects + with io.StringIO() as buf, redirect_stdout(buf): + _ = obj.algo_ro(advance=1) + _ = obj.algo_rw(advance=1) + printed = buf.getvalue() + assert "__algo_ro__" not in printed + assert "__algo_rw__" not in printed + + # Create a new object + obj = MyClassWithAlgorithms.build(double_input=False) + + # Read-only methods can be called on r/o objects + out = obj.algo_ro(advance=1) + assert out == obj.data.counter + 1 + out = obj.algo_ro(advance=1) + assert out == obj.data.counter + 1 + + # Read-write methods can be called too on r/o objects since they are marked as r/w + out = obj.algo_rw(advance=1) + assert out == 1 + out = obj.algo_rw(advance=1) + assert out == 2 + out = obj.algo_rw(advance=2) + assert out == 4 + + # Create a new object with a different dynamic attribute + obj_dyn = MyClassWithAlgorithms.build(double_input=False).mutable(validate=True) + obj_dyn.done = jnp.array(not obj_dyn.done, dtype=bool) + + # New objects with different dynamic attributes should be able to re-use the + # JIT-compiled functions from other objects + with io.StringIO() as buf, redirect_stdout(buf): + _ = obj.algo_ro(advance=1) + _ = obj.algo_rw(advance=1) + printed = buf.getvalue() + assert "__algo_ro__" not in printed + assert "__algo_rw__" not in printed + + # Create a new object with a different static attribute + obj_stat = MyClassWithAlgorithms.build(double_input=True) + + # New objects with different static attributes trigger the recompilation of the + # JIT-compiled functions... + with io.StringIO() as buf, redirect_stdout(buf): + _ = obj_stat.algo_ro(advance=1) + _ = obj_stat.algo_rw(advance=1) + printed = buf.getvalue() + assert "__algo_ro__" in printed + assert "__algo_rw__" in printed + + # ... that are cached as well by jax + with io.StringIO() as buf, redirect_stdout(buf): + _ = obj_stat.algo_ro(advance=1) + _ = obj_stat.algo_rw(advance=1) + printed = buf.getvalue() + assert "__algo_ro__" not in printed + assert "__algo_rw__" not in printed + + +def test_decorators_vmap(): + """Test automatic vectorization features of MyClassWithAlgorithms class.""" + + # Create a new object with scalar data + obj = MyClassWithAlgorithms.build(double_input=False) + + # Vectorize the entire object + obj_vec = obj.vectorize(batch_size=10) + assert obj_vec.vectorized is True + assert obj_vec.batch_size == 10 + assert id(obj_vec) != id(obj) + + # Calling methods of vectorized objects with scalar arguments should raise an error + with pytest.raises(ValueError): + _ = obj_vec.algo_ro(advance=1) + with pytest.raises(ValueError): + _ = obj_vec.algo_rw(advance=1) + + # Check that the r/o method provides automatically vectorized output and accepts + # vectorized input + out_vec = obj_vec.algo_ro(advance=jnp.array([1] * obj_vec.batch_size)) + assert out_vec.shape[0] == 10 + assert set(out_vec.tolist()) == {1} + + # Check that the r/w method provides automatically vectorized output and accepts + # vectorized input + out_vec = obj_vec.algo_rw(advance=jnp.array([1] * obj_vec.batch_size)) + assert out_vec.shape[0] == 10 + assert set(out_vec.tolist()) == {1} + out_vec = obj_vec.algo_rw(advance=jnp.array([1] * obj_vec.batch_size)) + assert set(out_vec.tolist()) == {2} + + # Extract a single object from the vectorized object + obj = obj_vec.extract_element(index=5) + assert obj.vectorized is False + assert obj.data.counter == obj_vec.data.counter[5]