From 205b4c9af96743e9f637f8476c0b316637ba98a9 Mon Sep 17 00:00:00 2001 From: Diego Ferigo Date: Wed, 5 Oct 2022 12:11:06 +0200 Subject: [PATCH] Update the parser logic to use the rod library --- src/jaxsim/parsers/sdf/parser.py | 104 ++++++++++++++++++------------- src/jaxsim/parsers/sdf/utils.py | 84 +++++++++---------------- 2 files changed, 91 insertions(+), 97 deletions(-) diff --git a/src/jaxsim/parsers/sdf/parser.py b/src/jaxsim/parsers/sdf/parser.py index fa6fa1feb..1af0f2279 100644 --- a/src/jaxsim/parsers/sdf/parser.py +++ b/src/jaxsim/parsers/sdf/parser.py @@ -1,10 +1,11 @@ import dataclasses +import pathlib from pathlib import Path -from typing import Dict, List, NamedTuple, Union +from typing import Dict, List, NamedTuple, Optional, Union import jax.numpy as jnp import numpy as np -import pysdf +import rod from jaxsim import logging from jaxsim.math.quaternion import Quaternion @@ -24,39 +25,57 @@ class SDFData(NamedTuple): joint_descriptions: List[descriptions.JointDescription] collision_shapes: List[descriptions.CollisionShape] - sdf_tree: pysdf.Model = None + sdf_model: Optional[rod.Model] = None model_pose: kinematic_graph.RootPose = kinematic_graph.RootPose() def extract_data_from_sdf( - sdf: Union[Path, str], + sdf: Union[pathlib.Path, str], model_name: Optional[str] = None ) -> SDFData: - if isinstance(sdf, str) and len(sdf) < 500 and Path(sdf).is_file(): - sdf = Path(sdf) + # Parse the SDF resource + sdf_element = rod.Sdf.load(sdf=sdf) - # Get the SDF string - sdf_string = sdf if isinstance(sdf, str) else sdf.read_text() + if len(sdf_element.models()) == 0: + raise RuntimeError("Failed to find any model in SDF resource") - # Parse the tree - sdf_tree = pysdf.SDF.from_xml(sdf_string=sdf_string, remove_blank_text=True) + # Assume the SDF resource has only one model, or the desired model name is given + sdf_models = {m.name: m for m in sdf_element.models()} + sdf_model = ( + sdf_element.models()[0] if len(sdf_models) == 1 else sdf_models[model_name] + ) + logging.debug(msg=f"Found model '{sdf_model.name}' in SDF resource") - # Detect whether the model is fixed base by checking joints with world parent exist. - # This link is a special link used to specify that the model's base should be fixed. - fixed_base = len([j for j in sdf_tree.model.joints if j.parent == "world"]) > 0 + # Detect fixed-base models by checking the existence of joints having world as parent + sdf_joints_with_world_parent = [ + j for j in sdf_model.joints() if j.parent == "world" + ] + fixed_base = len(sdf_joints_with_world_parent) > 0 - # Base link of the model. We take the first link in the SDF description. - base_link_name = sdf_tree.model.links[0].name + logging.debug( + msg="Model '{}' is {}".format( + sdf_model.name, "fixed-base" if fixed_base else "floating-base" + ) + ) + + # We extract the link connected to 'world', and consider it as base link. + # Instead, for floating-base models, we consider the first link as base link. + base_link_name = ( + sdf_joints_with_world_parent[0].name + if fixed_base + else sdf_model.links()[0].name + ) + logging.debug(msg=f"Considering '{base_link_name}' as base link") # Pose of the model - if sdf_tree.model.pose is None: + if sdf_model.pose is None: model_pose = kinematic_graph.RootPose() else: - w_H_m = utils.from_sdf_pose(pose=sdf_tree.model.pose) + W_H_M = sdf_model.pose.transform() model_pose = kinematic_graph.RootPose( - root_position=w_H_m[0:3, 3], - root_quaternion=Quaternion.from_dcm(dcm=w_H_m[0:3, 0:3]), + root_position=W_H_M[0:3, 3], + root_quaternion=Quaternion.from_dcm(dcm=W_H_M[0:3, 0:3]), ) # =========== @@ -69,9 +88,9 @@ def extract_data_from_sdf( name=l.name, mass=jnp.float32(l.inertial.mass), inertia=utils.from_sdf_inertial(inertial=l.inertial), - pose=utils.from_sdf_pose(pose=l.pose) if l.pose is not None else np.eye(4), + pose=l.pose.transform() if l.pose is not None else np.eye(4), ) - for l in sdf_tree.model.links + for l in sdf_model.links() if l.inertial.mass > 0 ] @@ -86,6 +105,7 @@ def extract_data_from_sdf( # to the world and combine their pose if fixed_base: + # Create a massless word link world_link = descriptions.LinkDescription( name="world", mass=0, inertia=np.zeros(shape=(6, 6)) ) @@ -100,20 +120,18 @@ def extract_data_from_sdf( parent=world_link, child=links_dict[j.child], jtype=utils.axis_to_jtype(axis=j.axis, type=j.type), - axis=utils.from_sdf_string_list(string_list=j.axis.xyz.text) + axis=np.array(j.axis.xyz.xyz) if j.axis is not None and j.axis.xyz is not None - and j.axis.xyz.text is not None + and j.axis.xyz.xyz is not None else None, - pose=utils.from_sdf_pose(pose=j.pose) - if j.pose is not None - else np.eye(4), + pose=j.pose.transform() if j.pose is not None else np.eye(4), ) - for j in sdf_tree.model.joints + for j in sdf_model.joints() if j.type == "fixed" and j.parent == "world" and j.child in links_dict.keys() - and j.pose.relative_to == "__model__" + and j.pose.relative_to in {"__model__", None} ] logging.debug( @@ -146,14 +164,14 @@ def extract_data_from_sdf( # ============ # Check that all joint poses are expressed w.r.t. their parent link - for j in sdf_tree.model.joints: + for j in sdf_model.joints(): if j.pose is None: continue if j.parent == "world": - if j.pose.relative_to == "__model__": + if j.pose.relative_to in {"__model__", None}: continue raise ValueError("Pose of fixed joint connecting to 'world' link not valid") @@ -169,12 +187,12 @@ def extract_data_from_sdf( parent=links_dict[j.parent], child=links_dict[j.child], jtype=utils.axis_to_jtype(axis=j.axis, type=j.type), - axis=utils.from_sdf_string_list(j.axis.xyz.text) + axis=np.array(j.axis.xyz.xyz) if j.axis is not None and j.axis.xyz is not None - and j.axis.xyz.text is not None + and j.axis.xyz.xyz is not None else None, - pose=utils.from_sdf_pose(pose=j.pose) if j.pose is not None else np.eye(4), + pose=j.pose.transform() if j.pose is not None else np.eye(4), initial_position=0.0, position_limit=( float(j.axis.limit.lower) @@ -192,7 +210,7 @@ def extract_data_from_sdf( friction_viscous=j.axis.dynamics.damping if j.axis is not None and j.axis.dynamics is not None - and j.axis.dynamics.friction is not None + and j.axis.dynamics.damping is not None else 0.0, position_limit_damper=j.axis.limit.dissipation if j.axis is not None @@ -205,7 +223,7 @@ def extract_data_from_sdf( and j.axis.limit.stiffness is not None else 0.0, ) - for j in sdf_tree.model.joints + for j in sdf_model.joints() if j.type in {"revolute", "prismatic", "fixed"} and j.parent != "world" and j.child in links_dict.keys() @@ -215,7 +233,7 @@ def extract_data_from_sdf( joint_dict = {j.child.name: j.name for j in joints} # Check that all the link poses are expressed wrt their parent joint - for l in sdf_tree.model.links: + for l in sdf_model.links(): if l.name not in links_dict: continue @@ -241,10 +259,10 @@ def extract_data_from_sdf( collisions: List[descriptions.CollisionShape] = [] # Parse the collisions - for link in sdf_tree.model.links: - for collision in link.colliders: + for link in sdf_model.links(): + for collision in link.collisions(): - if collision.geometry.box.to_xml() != "": + if collision.geometry.box is not None: box_collision = utils.create_box_collision( collision=collision, @@ -253,7 +271,7 @@ def extract_data_from_sdf( collisions.append(box_collision) - if collision.geometry.sphere.to_xml() != "": + if collision.geometry.sphere is not None: sphere_collision = utils.create_sphere_collision( collision=collision, @@ -263,14 +281,14 @@ def extract_data_from_sdf( collisions.append(sphere_collision) return SDFData( - model_name=sdf_tree.model.name, + model_name=sdf_model.name, link_descriptions=links, joint_descriptions=joints, collision_shapes=collisions, fixed_base=fixed_base, base_link_name=base_link_name, model_pose=model_pose, - sdf_tree=sdf_tree.model, + sdf_model=sdf_model, ) @@ -298,6 +316,6 @@ def build_model_from_sdf(sdf: Union[Path, str]) -> descriptions.ModelDescription ) # Store the parsed SDF tree as extra info - model = dataclasses.replace(model, extra_info=dict(sdf_tree=sdf_data.sdf_tree)) + model = dataclasses.replace(model, extra_info=dict(sdf_model=sdf_data.sdf_model)) return model diff --git a/src/jaxsim/parsers/sdf/utils.py b/src/jaxsim/parsers/sdf/utils.py index f74ce7551..12e4c7295 100644 --- a/src/jaxsim/parsers/sdf/utils.py +++ b/src/jaxsim/parsers/sdf/utils.py @@ -1,41 +1,18 @@ +import os from typing import Union import jax.numpy as jnp import numpy as np import numpy.typing as npt -import pysdf -from scipy.spatial.transform import Rotation as R +import rod from jaxsim.parsers import descriptions -def from_sdf_string_list(string_list: str, epsilon: float = 1e-06) -> npt.NDArray: - - lst = np.array(string_list.split(" "), dtype=float) - lst[np.abs(lst) < epsilon] = 0 - return lst - - -def from_sdf_pose(pose: pysdf.Pose) -> npt.NDArray: - - # Transform euler to DCM matrix (sequence of extrinsic rotations, i.e. all angles - # consider a fixed reference frame) - DCM = R.from_euler( - seq="xyz", angles=pose.orientation, degrees=pose.degrees - ).as_matrix() - - return np.vstack( - [ - np.hstack([DCM, np.vstack(pose.position)]), - np.array([0, 0, 0, 1]), - ] - ) - - -def from_sdf_inertial(inertial: pysdf.Link.Inertial) -> npt.NDArray: +def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray: from jaxsim.math.inertia import Inertia - from jaxsim.sixd import se3, so3 + from jaxsim.sixd import se3 # Extract the "mass" element m = inertial.mass @@ -43,52 +20,49 @@ def from_sdf_inertial(inertial: pysdf.Link.Inertial) -> npt.NDArray: # Extract the "inertia" element inertia_element = inertial.inertia + ixx = inertia_element.ixx + iyy = inertia_element.iyy + izz = inertia_element.izz + ixy = inertia_element.ixy if inertia_element.ixy is not None else 0.0 + ixz = inertia_element.ixz if inertia_element.ixz is not None else 0.0 + iyz = inertia_element.iyz if inertia_element.iyz is not None else 0.0 + # Build the 3x3 inertia matrix expressed in the CoM - I_com = np.array( + I_CoM = np.array( [ - [inertia_element.ixx, inertia_element.ixy, inertia_element.ixz], - [inertia_element.ixy, inertia_element.iyy, inertia_element.iyz], - [inertia_element.ixz, inertia_element.iyz, inertia_element.izz], + [ixx, ixy, ixz], + [ixy, iyy, iyz], + [ixz, iyz, izz], ] ) # Build the 6x6 generalized inertia at the CoM - I_generalized = Inertia.to_sixd(mass=m, com=np.zeros(3), I=I_com) - - # Transform euler to DCM matrix (sequence of extrinsic rotations, i.e. all angles - # consider a fixed reference frame) - l_R_com = so3.SO3.from_matrix( - R.from_euler( - seq="xyz", angles=inertial.pose.orientation, degrees=inertial.pose.degrees - ).as_matrix() - ) + I_generalized = Inertia.to_sixd(mass=m, com=np.zeros(3), I=I_CoM) # Compute the transform from the inertial frame (CoM) to the link frame - l_H_com = se3.SE3.from_rotation_and_translation( - rotation=l_R_com, translation=np.array(inertial.pose.position) - ) + L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4) # We need its inverse - com_H_l = l_H_com.inverse() - com_X_l = com_H_l.adjoint() + CoM_H_L = se3.SE3.from_matrix(matrix=L_H_CoM).inverse() + CoM_X_L: npt.NDArray = CoM_H_L.adjoint() # Express the CoM inertia matrix in the link frame - I_expressed_in_link_frame = com_X_l.T @ I_generalized @ com_X_l + I_expressed_in_link_frame = CoM_X_L.T @ I_generalized @ CoM_X_L return jnp.array(I_expressed_in_link_frame) def axis_to_jtype( - axis: pysdf.Joint.Axis, type: str + axis: rod.Axis, type: str ) -> Union[descriptions.JointType, descriptions.JointDescriptor]: if type == "fixed": return descriptions.JointType.F - if not (axis.xyz is not None and axis.xyz.text is not None): + if not (axis.xyz is not None and axis.xyz.xyz is not None): raise ValueError("Failed to read axis xyz data") - axis_xyz = from_sdf_string_list(axis.xyz.text) + axis_xyz = np.array(axis.xyz.xyz) if np.allclose(axis_xyz, [1, 0, 0]) and type in {"revolute", "continuous"}: return descriptions.JointType.Rx @@ -122,7 +96,7 @@ def axis_to_jtype( def create_box_collision( - collision: pysdf.Collision, link_description: descriptions.LinkDescription + collision: rod.Collision, link_description: descriptions.LinkDescription ) -> descriptions.BoxCollision: x, y, z = collision.geometry.box.size @@ -145,7 +119,7 @@ def create_box_collision( - center ) - H = from_sdf_pose(pose=collision.pose) + H = collision.pose.transform() if collision.pose is not None else np.eye(4) center_wrt_link = (H @ np.hstack([center, 1.0]))[0:-1] box_corners_wrt_link = ( @@ -167,7 +141,7 @@ def create_box_collision( def create_sphere_collision( - collision: pysdf.Collision, link_description: descriptions.LinkDescription + collision: rod.Collision, link_description: descriptions.LinkDescription ) -> descriptions.SphereCollision: # From https://stackoverflow.com/a/26127012 @@ -191,9 +165,11 @@ def fibonacci_sphere(samples: int) -> npt.NDArray: return np.vstack(points) r = collision.geometry.sphere.radius - sphere_points = r * fibonacci_sphere(samples=250) + sphere_points = r * fibonacci_sphere( + samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default=250)) + ) - H = from_sdf_pose(pose=collision.pose) + H = collision.pose.transform() if collision.pose is not None else np.eye(4) center_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[0:-1]