Skip to content

Commit

Permalink
Compute KinematicGraph transforms specifying generic joint positions
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Apr 12, 2024
1 parent 946ec19 commit 4fc5fa3
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 16 deletions.
3 changes: 2 additions & 1 deletion src/jaxsim/math/joint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ def supported_joint_motion(
S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)]))

case JointType.F:
raise ValueError("Fixed joints shouldn't be here")
pre_H_suc = jaxlie.SE3.identity()
S = jnp.zeros(shape=(6, 1))

case _:
raise ValueError(joint_type)
Expand Down
88 changes: 73 additions & 15 deletions src/jaxsim/parsers/kinematic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,10 +575,57 @@ class KinematicGraphTransforms:

graph: KinematicGraph

transform_cache: dict[str, npt.NDArray] = dataclasses.field(
_transform_cache: dict[str, npt.NDArray] = dataclasses.field(
default_factory=dict, init=False, repr=False, compare=False
)

_initial_joint_positions: dict[str, float] = dataclasses.field(
init=False, repr=False, compare=False
)

def __post_init__(self) -> None:

super().__setattr__(
"_initial_joint_positions",
{joint.name: joint.initial_position for joint in self.graph.joints},
)

@property
def initial_joint_positions(self) -> npt.NDArray:

return np.atleast_1d(
np.array(list(self._initial_joint_positions.values()))
).astype(float)

@initial_joint_positions.setter
def initial_joint_positions(
self,
positions: npt.NDArray | Sequence,
joint_names: Sequence[str] | None = None,
) -> None:

joint_names = (
joint_names
if joint_names is not None
else list(self._initial_joint_positions.keys())
)

s = np.atleast_1d(np.array(positions).squeeze())

if s.size != len(joint_names):
raise ValueError(s.size, len(joint_names))

for joint_name in joint_names:
if joint_name not in self._initial_joint_positions:
raise ValueError(joint_name)

# Clear transform cache.
self._transform_cache.clear()

# Update initial joint positions.
for joint_name, position in zip(joint_names, s):
self._initial_joint_positions[joint_name] = position

def transform(self, name: str) -> npt.NDArray:
"""
Compute the SE(3) transform of elements belonging to the kinematic graph.
Expand All @@ -591,32 +638,30 @@ def transform(self, name: str) -> npt.NDArray:
"""

# If the transform was already computed, return it.
if name in self.transform_cache:
return self.transform_cache[name]
if name in self._transform_cache:
return self._transform_cache[name]

# If the name is a joint, compute M_H_J transform.
if name in self.graph.joint_names():

# Get the joint.
joint = self.graph.joints_dict[name]

if joint.initial_position != 0.0:
msg = f"Ignoring unsupported initial position of joint '{name}'"
logging.warning(msg=msg)

# Get the transform of the parent link.
M_H_L = self.transform(name=joint.parent.name)

# Rename the pose of the predecessor joint frame w.r.t. its parent link.
L_H_pre = joint.pose

# Compute the joint transform from the predecessor to the successor frame.
# Note: we assume that the joint angle is always 0.
pre_H_J = np.eye(4)
pre_H_J = self.pre_H_suc(
joint_type=joint.jtype,
joint_position=self._initial_joint_positions[joint.name],
)

# Compute the M_H_J transform.
self.transform_cache[name] = M_H_L @ L_H_pre @ pre_H_J
return self.transform_cache[name]
self._transform_cache[name] = M_H_L @ L_H_pre @ pre_H_J
return self._transform_cache[name]

# If the name is a link, compute M_H_L transform.
if name in self.graph.link_names():
Expand All @@ -641,8 +686,8 @@ def transform(self, name: str) -> npt.NDArray:
J_H_L = link.pose

# Compute the M_H_L transform.
self.transform_cache[name] = M_H_J @ J_H_L
return self.transform_cache[name]
self._transform_cache[name] = M_H_J @ J_H_L
return self._transform_cache[name]

# It can only be a plain frame.
if name not in self.graph.frame_names():
Expand All @@ -658,8 +703,8 @@ def transform(self, name: str) -> npt.NDArray:
L_H_F = frame.pose

# Compute the M_H_F transform.
self.transform_cache[name] = M_H_L @ L_H_F
return self.transform_cache[name]
self._transform_cache[name] = M_H_L @ L_H_F
return self._transform_cache[name]

def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:
"""
Expand All @@ -682,3 +727,16 @@ def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:
# and i the frame of the desired link|joint|frame.
return np.array(jaxsim.math.Transform.inverse(M_H_R)) @ M_H_target

@staticmethod
def pre_H_suc(
joint_type: descriptions.JointType | descriptions.JointDescriptor,
joint_position: float | None = None,
) -> npt.NDArray:

import jaxsim.math

return np.array(
jaxsim.math.supported_joint_motion(
joint_type=joint_type, joint_position=joint_position
)[0]
)

0 comments on commit 4fc5fa3

Please sign in to comment.