Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve compilation time of RBDAs for models with many DoFs #153

Merged
merged 6 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions src/jaxsim/api/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import jax
import jax.numpy as jnp
import numpy as np

import jaxsim.api as js
import jaxsim.typing as jtp
Expand All @@ -30,17 +29,9 @@ def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
# Note: the index of the joint for RBDAs starts from 1, but
# the index for accessing the right element starts from 0.
# Therefore, there is a -1.
return (
jnp.array(
np.argwhere(
np.array(model.kin_dyn_parameters.joint_model.joint_names)
== joint_name
)
- 1
)
.squeeze()
.astype(int)
)
return jnp.array(
model.kin_dyn_parameters.joint_model.joint_names.index(joint_name) - 1
).squeeze()
return jnp.array(-1).astype(int)


Expand Down
19 changes: 9 additions & 10 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,21 +382,20 @@ def joint_transforms_and_motion_subspaces(
)

# Compute the transforms and motion subspaces of the joints.
# TODO: understand how to use joint_indices to access joint_types, right now
# it fails when used within a JIT context.
pre_H_suc_and_S = [
supported_joint_motion(
joint_type=self.joint_model.joint_types[i + 1],
joint_position=jnp.array(s),
if self.number_of_joints() == 0:
pre_H_suc_J, S_J = jnp.empty((0, 4, 4)), jnp.empty((0, 6, 1))
else:
pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)(
jnp.array(self.joint_model.joint_types[1:]).astype(int),
jnp.array(self.joint_model.joint_axis),
jnp.array(joint_positions),
)
for i, s in enumerate(jnp.array(joint_positions).astype(float))
]

# Extract the transforms and motion subspaces of the joints.
# We stack the base transform W_H_B at index 0, and a dummy motion subspace
# for either the fixed or free-floating joint connecting the world to the base.
pre_H_suc = jnp.stack([W_H_B] + [H for H, _ in pre_H_suc_and_S])
S = jnp.stack([jnp.vstack(jnp.zeros(6))] + [S for _, S in pre_H_suc_and_S])
pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J])
S = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])

# Extract the successor-to-child fixed transforms.
# Note that here we include also the index 0 since suc_H_child[0] stores the
Expand Down
74 changes: 32 additions & 42 deletions src/jaxsim/math/joint_model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
from __future__ import annotations

import functools

import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxlie
from jax_dataclasses import Static

import jaxsim.typing as jtp
from jaxsim.parsers.descriptions import (
JointDescriptor,
JointGenericAxis,
JointType,
ModelDescription,
)
from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription
from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms

from .rotation import Rotation
Expand Down Expand Up @@ -46,7 +39,8 @@ class JointModel:

joint_dofs: Static[tuple[int, ...]]
joint_names: Static[tuple[str, ...]]
joint_types: Static[tuple[JointType | JointDescriptor, ...]]
joint_types: Static[tuple[JointType, ...]]
joint_axis: Static[tuple[JointGenericAxis, ...]]

@staticmethod
def build(description: ModelDescription) -> JointModel:
Expand Down Expand Up @@ -114,7 +108,8 @@ def build(description: ModelDescription) -> JointModel:
# Static attributes
joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]),
joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
joint_types=tuple([JointType.F] + [j.jtype for j in ordered_joints]),
joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
joint_axis=tuple([j.axis for j in ordered_joints]),
)

def parent_H_child(
Expand Down Expand Up @@ -226,59 +221,54 @@ def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:
return self.suc_H_i[joint_index]


@functools.partial(jax.jit, static_argnames=["joint_type"])
@jax.jit
def supported_joint_motion(
joint_type: JointType | JointDescriptor, joint_position: jtp.VectorLike
joint_type: JointType, joint_axis: JointGenericAxis, joint_position: jtp.VectorLike
flferretti marked this conversation as resolved.
Show resolved Hide resolved
) -> tuple[jtp.Matrix, jtp.Array]:
"""
Compute the homogeneous transformation and motion subspace of a joint.

Args:
joint_type: The type of the joint.
joint_axis: The axis of rotation or translation of the joint.
joint_position: The position of the joint.

Returns:
A tuple containing the homogeneous transformation and the motion subspace.
"""

if isinstance(joint_type, JointType):
type_enum = joint_type
elif isinstance(joint_type, JointDescriptor):
type_enum = joint_type.joint_type
else:
raise ValueError(joint_type)

# Prepare the joint position
s = jnp.array(joint_position).astype(float)

match type_enum:

case JointType.R:
joint_type: JointGenericAxis
def compute_F():
return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1))

pre_H_suc = jaxlie.SE3.from_rotation(
rotation=jaxlie.SO3.from_matrix(
Rotation.from_axis_angle(vector=s * joint_type.axis)
)
def compute_R():
pre_H_suc = jaxlie.SE3.from_rotation(
rotation=jaxlie.SO3.from_matrix(
Rotation.from_axis_angle(vector=s * joint_axis)
)
)

S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_type.axis.squeeze()]))

case JointType.P:
joint_type: JointGenericAxis

pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.identity(),
translation=jnp.array(s * joint_type.axis),
)
S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_axis.squeeze()]))
return pre_H_suc, S

S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)]))
def compute_P():
pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.identity(),
translation=jnp.array(s * joint_axis),
)

case JointType.F:
pre_H_suc = jaxlie.SE3.identity()
S = jnp.zeros(shape=(6, 1))
S = jnp.vstack(jnp.hstack([joint_axis.squeeze(), jnp.zeros(3)]))
return pre_H_suc, S

case _:
raise ValueError(joint_type)
pre_H_suc, S = jax.lax.switch(
index=joint_type,
branches=(
compute_F, # JointType.Fixed
compute_R, # JointType.Revolute
compute_P, # JointType.Prismatic
),
)

return pre_H_suc.as_matrix(), S
2 changes: 1 addition & 1 deletion src/jaxsim/parsers/descriptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .collision import BoxCollision, CollidablePoint, CollisionShape, SphereCollision
from .joint import JointDescription, JointDescriptor, JointGenericAxis, JointType
from .joint import JointDescription, JointGenericAxis, JointType
from .link import LinkDescription
from .model import ModelDescription
45 changes: 10 additions & 35 deletions src/jaxsim/parsers/descriptions/joint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

import dataclasses
import enum
from typing import Tuple, Union
from typing import ClassVar, Tuple, Union

import jax_dataclasses
import numpy as np
Expand All @@ -14,39 +13,15 @@
from .link import LinkDescription


@enum.unique
class JointType(enum.IntEnum):
"""
Type of supported joints.
"""

@staticmethod
def _generate_next_value_(name, start, count, last_values):
# Start auto Enum value from 0 instead of 1
return count

#: Fixed joint.
F = enum.auto()

#: Revolute joint (1 DoF around axis).
R = enum.auto()

#: Prismatic joint (1 DoF along axis).
P = enum.auto()


@jax_dataclasses.pytree_dataclass
class JointDescriptor:
"""
Base class for joint types requiring to store additional metadata.
"""

#: The joint type.
joint_type: JointType
@dataclasses.dataclass(frozen=True)
class JointType:
Fixed: ClassVar[int] = 0
Revolute: ClassVar[int] = 1
Prismatic: ClassVar[int] = 2


@jax_dataclasses.pytree_dataclass
class JointGenericAxis(JointDescriptor):
class JointGenericAxis:
"""
A joint requiring the specification of a 3D axis.
"""
Expand All @@ -55,7 +30,7 @@ class JointGenericAxis(JointDescriptor):
axis: jtp.Vector

def __hash__(self) -> int:
return hash((self.joint_type, tuple(np.array(self.axis).tolist())))
return hash((tuple(np.array(self.axis).tolist())))

def __eq__(self, other: JointGenericAxis) -> bool:
if not isinstance(other, JointGenericAxis):
Expand All @@ -73,7 +48,7 @@ class JointDescription(JaxsimDataclass):
name (str): The name of the joint.
axis (npt.NDArray): The axis of rotation or translation for the joint.
pose (npt.NDArray): The pose transformation matrix of the joint.
jtype (Union[JointType, JointDescriptor]): The type of the joint.
jtype (JointType): The type of the joint.
child (LinkDescription): The child link attached to the joint.
parent (LinkDescription): The parent link attached to the joint.
index (Optional[int]): An optional index for the joint.
Expand All @@ -89,7 +64,7 @@ class JointDescription(JaxsimDataclass):
name: jax_dataclasses.Static[str]
axis: npt.NDArray
pose: npt.NDArray
jtype: jax_dataclasses.Static[Union[JointType, JointDescriptor]]
jtype: jax_dataclasses.Static[JointType]
child: LinkDescription = dataclasses.dataclass(repr=False)
parent: LinkDescription = dataclasses.dataclass(repr=False)

Expand Down
8 changes: 6 additions & 2 deletions src/jaxsim/parsers/kinematic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,7 @@ def transform(self, name: str) -> npt.NDArray:
# Compute the joint transform from the predecessor to the successor frame.
pre_H_J = self.pre_H_suc(
joint_type=joint.jtype,
joint_axis=joint.axis,
joint_position=self._initial_joint_positions[joint.name],
)

Expand Down Expand Up @@ -762,14 +763,17 @@ def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:

@staticmethod
def pre_H_suc(
joint_type: descriptions.JointType | descriptions.JointDescriptor,
joint_type: descriptions.JointType,
joint_axis: descriptions.JointGenericAxis,
joint_position: float | None = None,
) -> npt.NDArray:

import jaxsim.math

return np.array(
jaxsim.math.supported_joint_motion(
joint_type=joint_type, joint_position=joint_position
joint_type=joint_type,
joint_axis=joint_axis,
joint_position=joint_position,
)[0]
)
2 changes: 1 addition & 1 deletion src/jaxsim/parsers/rod/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def build_model_description(
considered_joints=[
j.name
for j in sdf_data.joint_descriptions
if j.jtype is not descriptions.JointType.F
if j.jtype is not descriptions.JointType.Fixed
],
)

Expand Down
12 changes: 4 additions & 8 deletions src/jaxsim/parsers/rod/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:

def joint_to_joint_type(
joint: rod.Joint,
) -> descriptions.JointType | descriptions.JointDescriptor:
) -> descriptions.JointType:
"""
Extract the joint type from an SDF joint.

Expand All @@ -76,7 +76,7 @@ def joint_to_joint_type(
joint_type = joint.type

if joint_type == "fixed":
return descriptions.JointType.F
return descriptions.JointType.Fixed

if not (axis.xyz is not None and axis.xyz.xyz is not None):
raise ValueError("Failed to read axis xyz data")
Expand All @@ -86,14 +86,10 @@ def joint_to_joint_type(
axis_xyz = axis_xyz / np.linalg.norm(axis_xyz)

if joint_type in {"revolute", "continuous"}:
return descriptions.JointGenericAxis(
joint_type=descriptions.JointType.R, axis=axis_xyz
)
return descriptions.JointType.Revolute

if joint_type == "prismatic":
return descriptions.JointGenericAxis(
joint_type=descriptions.JointType.P, axis=axis_xyz
)
return descriptions.JointType.Prismatic

raise ValueError("Joint not supported", axis_xyz, joint_type)

Expand Down
Loading